pytorch编写的用于自然语言处理分类或顺序任务的深层神经网络。

dnnnlp的Python项目详细描述


pytorch-深层神经网络-自然语言处理

KZXuan的1.0版

包含cnn、rnn和transformer层,以及pytorch为nlp中的分类任务实现的模型。

  • 新设计的模块。
  • 减少使用复杂度。
  • 使用mask作为序列长度标识符。
  • 多GPU并行网格搜索。

即将推出:新的序列标签支持。


依赖性

python 3.5+&pytorch 1.2.0+


API Document


超参数

NameTypeDefaultDescription
n_gpuint1The number of GPUs (0 means no GPU acceleration).
space_turboboolTrueAccelerate with more GPU memories.
data_shuffleboolTureDisrupt data for training.
emb_typestrNoneEmbedding modes contain None, 'const' or 'variable'.
emb_dimint300Embedding dimension (or feature dimension).
n_classint2Number of target classes.
n_hiddenint50Number of hidden nodes, or output channels of CNN.
learning_ratefloat0.01Learning rate.
l2_regfloat1e-6L2 regular.
batch_sizeint128Number of samples for one batch.
iter_timesint30Number of iterations.
display_stepint2The number of iterations between each output of the result.
drop_probfloat0.1Dropout ratio.
eval_metricstr'accuracy'Evaluation metrics contain 'accuracy', 'macro', 'class1', etc.

使用量

# import our modulesfromdnnnlp.modelimportRNNModelfromdnnnlp.execimportdefault_args,Classify# load the embedding matrixemb_mat=np.array((-1,300))# load the train datatrain_x=np.array((800,50))train_y=np.array((800,))train_mask=np.array((800,50))# load the test datatest_x=np.array((200,50))test_y=np.array((200,))test_mask=np.array((200,50))# get the default argumentsargs=default_args()# modify part of the argumentsargs.space_turbo=Falseargs.n_hidden=100args.batch_size=32
  • 分类
# initilize a modelmodel=RNNModel(args,emb_mat,bi_direction=False,rnn_type='GRU',use_attention=True)# initilize a classifiernn=Classify(model,args,train_x,train_y,train_mask,test_x,test_y,test_mask)# do training and testingevals=nn.train_test(device_id=0)
  • 跑几次,得到平均分。
# initilize a modelmodel=CNNModel(args,emb_mat,kernel_widths=[2,3,4])# initilize a classifiernn=Classify(model,args,train_x,train_y,train_mask)# run the model several timesavg_evals=average_several_run(nn.cross_validation,args,n_times=8,n_paral=4,fold=5)
  • 参数的网格搜索。
# initilize a modelmodel=TransformerModel(args,n_layer=12,n_head=8)# initilize a classifiernn=Classify(model,args,train_x,train_y,train_mask,test_x,test_y,test_mask)# set searching paramsparams_search={'learning_rate':[0.1,0.01],'n_hidden':[50,100]}# run grid searchmax_evals=grid_search(nn,nn.train_test,args,params_search)

历史

1.0版

  • 将项目dnn重命名为dnnnlp
  • 删除文件base,添加文件utils
  • 优化并重命名SoftmaxLayerSoftAttentionLayer
  • 重写并重命名EmbeddingLayerCNNLayerRNNLayer
  • 重写MultiheadAttentionLayer:基于nn.MultiheadAttention的打包注意层。
  • 重写TransformerLayer:支持新的MultiheadAttentionLayer
  • 优化并重命名CNNModelRNNModelTransformerModel
  • 优化并重命名Classify:一个高度适用的分类执行模块。
  • 重写average_several_rungrid_search:支持多GPU并行。
  • 支持Pythorch 1.2.0。

0.12版

  • 更新RNN_layer:完全支持tanh、lstm和gru。
  • 修复某些掩码操作中的错误。
  • 支持Pythorch 1.1.0。

旧版本0.12.3

0.11版

  • 提供一种使用更多GPU存储器的加速方法。
  • 解决数据读取异常导致的内存消耗问题。
  • 添加multi_head_attention_layer:包装变压器的多头注意事项。
  • 添加Transformer_layerTransformer_model:封装我们自己编写的变压器层和模型。
  • 支持培训数据中断。

0.10版

  • 把代码分成四个文件:baselayermodelexec
  • 添加CNN_layerCNN_model:包装cnn层和模型。
  • 支持每种型号多个GPU并行。

0.9版

  • 解决输出格式问题。
  • 修正LSTM_classify交叉验证部分的统计错误。
  • 重命名:LSTM_modelRNN_layerself_attentionself_attention_layer
  • 添加softmax_layer:一个完全连接的封装层。

0.8版

  • 调整LSTM_classify中函数的适用性,以避免在LSTM_sequence中重写。
  • 优化参数传递方式。
  • 更完善的评估机制。

0.7版

  • 添加LSTM_sequence:用于LSTM_model的序列标记模块。
  • 解决了层次分类中的NaN值问题。
  • 支持Pythorch 1.0.0。

0.6版

  • 更新LSTM_classify:支持分级分类。
  • GRU_model合并到LSTM_model中。
  • 适应CPU操作。

0.5版

  • 拆分LSTM_classify的运行部分以减少自定义模型的重写。
  • 为可视化输出添加控件。
  • 创建函数average_several_run:支持经过多次训练和测试获得平均分。
  • 创建函数grid_search:支持参数的网格搜索。

0.4版

  • addGRU_model:基于nn.GRU的包装gru模型。
  • 支持L2常规。

0.3版

  • addself_attention:提供注意机制支持。
  • updateLSTM_classify:适应复杂的自定义模型。

0.2版

  • 支持嵌入模式选择。
  • nn.Dropout的默认用法。
  • 创建函数default_args以提供默认超参数。

0.1版

  • 项目初始化dnn:基于pytorch 0.4.1。
  • 添加LSTM_model:基于nn.LSTM的包装lstm模型。
  • addLSTM_classify:lstm模型的一个分类模块,支持列车测试和corss验证。

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
java从类对象访问静态变量   java无法在三星A5上使用Toast(2016年)   java处理阻止图像在选择其他图像时消失   java Install4j Linux应用程序   swing在jpanel form java上具有暂停/恢复按钮   java Log4J登录年份文件夹   java XmlPullParser资源管理   JavaGoogleCloudEndpoints方法总是导致NullPointerException,为什么?   java JSON到带有POJO和Enum的Spring控制器   java制作自定义名称生成器?   java仅在设备屏幕的特定部分显示google地图多段线   java图像没有重新绘制,只是相乘   java如何将格式化字符串转换为浮点?   java无法提前很长时间安排TimerTask   当引用函数::和时,java Intellij IDEA无法解析“和”函数接口方法   java结束了dowhile循环   java Spring路径变量绑定   log4j API中FileAppender中的java问题   java使用QMessageBox从选项列表中进行选择