基于pytorch的通用多特征分类库

nymph的Python项目详细描述


nymph

基于Pytorch的多特征分类框架

概述

基于Pytorch的多特征序列标注和普通分类框架,包装的还算可以。可以直接照搬demo,拿csv文件去训练预测。

功能

  • 多特征分类(特征包括字符型、数值型,其中字符型最好是单个词而非词组或句子)
  • 输出详细分类详情

原理

  • 预处理:对各列非数值类数据分别构建词表并使用Embedding获得低维稠密向量,对数值类数据进行标准化,然后拼接获得各行对应向量
  • 模型:
    • 普通分类:全连接神经网络,NormClassifier(具体效果看特征)
    • 序列标注:Bi-LSTM-CRF,SeqClassifier(效果较好)
  • 预测:使用sklearn获取f1分数,并且获得各类别分类详情

安装

使用如下命令进行安装

pip install -U nymph

使用示例

训练数据

数据可见test.csv

如图:

test_data

普通分类

训练模型

源码如下,具体可参见train_demo_by_norm.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportNormDataset,split_datasetfromnymph.modulesimportNormClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves'if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=NormClassifier()classifier.init_data_processor(data,target_name='label')# 构建数据集norm_ds=NormDataset(data)train_ratio=0.7dev_ratio=0.2test_ratio=0.1train_ds,dev_ds,test_ds=split_dataset(norm_ds,(train_ratio,dev_ratio,test_ratio))# 训练模型# classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)classifier.train(train_set=norm_ds,dev_set=norm_ds,save_path=save_path)# 测试模型test_score=classifier.score(norm_ds)print('test_score',test_score)# 预测模型pred=classifier.predict(norm_ds)print(pred)
训练结果

终端输出

train_demo_by_norm_result

预测模型

源码如下,具体可参见predict_demo_by_norm.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportNormDataset,split_datasetfromnymph.modulesimportNormClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves'if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=NormClassifier()# 加载分类器classifier.load(save_path)# 构建数据集norm_ds=NormDataset(data)# 预测模型pred=classifier.predict(norm_ds)print(pred)# 获取各类别分类结果,并保存信息至文件中classifier.report(norm_ds,'report.csv')# 对数据进行预测,并将数据和预测结果写入到新的文件中classifier.summary(norm_ds,'summary.csv')
预测结果

如图:predict_demo_by_norm_result

report.csv内容

report

summary.csv内容

summary

序列标注

训练模型

源码如下,具体可参见train_demo_by_seq.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportSeqDataset,split_datasetfromnymph.modulesimportSeqClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves_seq'defsplit_fn(dataset:list):returnlist(range(len(dataset)+1))if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=SeqClassifier()classifier.init_data_processor(data,target_name='label')# 构建数据集norm_ds=SeqDataset(data,split_fn=split_fn,min_len=4)train_ratio=0.7dev_ratio=0.2test_ratio=0.1train_ds,dev_ds,test_ds=split_dataset(norm_ds,(train_ratio,dev_ratio,test_ratio))# 训练模型# classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)classifier.train(train_set=norm_ds,dev_set=norm_ds,save_path=save_path)# 测试模型test_score=classifier.score(norm_ds)print('test_score',test_score)# 预测模型pred=classifier.predict(norm_ds)print(pred)
训练结果

终端输出

train_demo_by_seq_result

预测模型

源码如下,具体可参见predict_demo_by_seq.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportSeqDataset,split_datasetfromnymph.modulesimportSeqClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves_seq'defsplit_fn(dataset:list):returnlist(range(len(dataset)+1))if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=SeqClassifier()# 加载分类器classifier.load(save_path)# 构建数据集seq_ds=SeqDataset(data,split_fn=split_fn,min_len=4)# 预测模型pred=classifier.predict(seq_ds)print(pred)# 获取各类别分类结果,并保存信息至文件中classifier.report(seq_ds,'seq_demo_report.csv')# 对数据进行预测,并将数据和预测结果写入到新的文件中classifier.summary(seq_ds,'seq_demo_summary.csv')

如图:predict_demo_by_seq_result

seq_demo_report.csv内容

seq_demo_report

seq_demo_summary.csv内容

seq_demo_summary

更新历史

  • 0.1.0: 初始化项目,增加全连接模型
  • 0.2.0: 增加序列标注模型,大幅重构项目结构与内部实现代码
  • 0.2.1: 更新代码,使GPU和CPU下同时可用
  • 0.2.2: 增加将训练过程的loss和f1值写入到TensorBoard中
  • 0.2.3: 增加Norm Classifier的TensorBoard功能

参考

  1. python - Sorting list based on values from another list? - Stack Overflow

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
java窗口。位置和窗口。公开问题   java如何从存储在ArrayList<Node>中的动态生成的文本字段中获取文本?   java如何立即关闭InputStream?   如何重新启动Java程序以激活环境变量   java搜索字符串是否相差一个字符   java CFB模式输出与CTR输出相同;我做错什么了吗?   java如何在javaFX中将实例化对象添加到Stage   java如何在jtextarea上打印来自不同类的文本消息   java以编程方式确定IOException的原因?   限制Java NIO通道(文件或socket)中的可用内容   javajaxb与JDOM:是否可以使用JAXB更新xml文件   批处理文件到java测试   JavaFX:stage的作用是什么。可设置大小(false)是否会导致额外的页边距?   java有没有办法告诉IntelliJ按需堆叠参数?   java Seam会话范围的组件在下一个请求中消失   java Google Web Toolkit对开发复杂的java脚本有用吗?   安卓 studio java ArrayList正在检索最高值   java为什么递归地用随机数填充LinkedList时会出现StackOverflowException?