基于pytorch的通用多特征分类库
nymph的Python项目详细描述
nymph
基于Pytorch的多特征分类框架
概述
基于Pytorch的多特征序列标注和普通分类框架,包装的还算可以。可以直接照搬demo,拿csv文件去训练预测。
功能
- 多特征分类(特征包括字符型、数值型,其中字符型最好是单个词而非词组或句子)
- 输出详细分类详情
原理
- 预处理:对各列非数值类数据分别构建词表并使用Embedding获得低维稠密向量,对数值类数据进行标准化,然后拼接获得各行对应向量
- 模型:
- 普通分类:全连接神经网络,NormClassifier(具体效果看特征)
- 序列标注:Bi-LSTM-CRF,SeqClassifier(效果较好)
- 预测:使用sklearn获取f1分数,并且获得各类别分类详情
安装
使用如下命令进行安装
pip install -U nymph
使用示例
训练数据
数据可见test.csv
如图:
普通分类
训练模型
源码如下,具体可参见train_demo_by_norm.py:
# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportNormDataset,split_datasetfromnymph.modulesimportNormClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves'if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=NormClassifier()classifier.init_data_processor(data,target_name='label')# 构建数据集norm_ds=NormDataset(data)train_ratio=0.7dev_ratio=0.2test_ratio=0.1train_ds,dev_ds,test_ds=split_dataset(norm_ds,(train_ratio,dev_ratio,test_ratio))# 训练模型# classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)classifier.train(train_set=norm_ds,dev_set=norm_ds,save_path=save_path)# 测试模型test_score=classifier.score(norm_ds)print('test_score',test_score)# 预测模型pred=classifier.predict(norm_ds)print(pred)
训练结果
终端输出
预测模型
源码如下,具体可参见predict_demo_by_norm.py
# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportNormDataset,split_datasetfromnymph.modulesimportNormClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves'if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=NormClassifier()# 加载分类器classifier.load(save_path)# 构建数据集norm_ds=NormDataset(data)# 预测模型pred=classifier.predict(norm_ds)print(pred)# 获取各类别分类结果,并保存信息至文件中classifier.report(norm_ds,'report.csv')# 对数据进行预测,并将数据和预测结果写入到新的文件中classifier.summary(norm_ds,'summary.csv')
预测结果
如图:
report.csv
内容
summary.csv
内容
序列标注
训练模型
源码如下,具体可参见train_demo_by_seq.py:
# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportSeqDataset,split_datasetfromnymph.modulesimportSeqClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves_seq'defsplit_fn(dataset:list):returnlist(range(len(dataset)+1))if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=SeqClassifier()classifier.init_data_processor(data,target_name='label')# 构建数据集norm_ds=SeqDataset(data,split_fn=split_fn,min_len=4)train_ratio=0.7dev_ratio=0.2test_ratio=0.1train_ds,dev_ds,test_ds=split_dataset(norm_ds,(train_ratio,dev_ratio,test_ratio))# 训练模型# classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)classifier.train(train_set=norm_ds,dev_set=norm_ds,save_path=save_path)# 测试模型test_score=classifier.score(norm_ds)print('test_score',test_score)# 预测模型pred=classifier.predict(norm_ds)print(pred)
训练结果
终端输出
预测模型
源码如下,具体可参见predict_demo_by_seq.py
# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportSeqDataset,split_datasetfromnymph.modulesimportSeqClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves_seq'defsplit_fn(dataset:list):returnlist(range(len(dataset)+1))if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=SeqClassifier()# 加载分类器classifier.load(save_path)# 构建数据集seq_ds=SeqDataset(data,split_fn=split_fn,min_len=4)# 预测模型pred=classifier.predict(seq_ds)print(pred)# 获取各类别分类结果,并保存信息至文件中classifier.report(seq_ds,'seq_demo_report.csv')# 对数据进行预测,并将数据和预测结果写入到新的文件中classifier.summary(seq_ds,'seq_demo_summary.csv')
如图:
seq_demo_report.csv
内容
seq_demo_summary.csv
内容
更新历史
- 0.1.0: 初始化项目,增加全连接模型
- 0.2.0: 增加序列标注模型,大幅重构项目结构与内部实现代码
- 0.2.1: 更新代码,使GPU和CPU下同时可用
- 0.2.2: 增加将训练过程的loss和f1值写入到TensorBoard中
- 0.2.3: 增加Norm Classifier的TensorBoard功能
参考
- 项目
标签: