基于掩蔽序列训练的非监督NMT算法

unmass的Python项目详细描述


质量

MASS是一种新的基于序列到序列的语言生成任务的预训练方法。 它在编码器中随机屏蔽一个句子片段,然后在解码器中预测它。在

img

MASS可以应用于跨语言任务,如神经机器翻译(NMT), 以及诸如文本摘要之类的单语任务。 当前的代码库支持无监督的NMT(基于XLM实现)。在

作者:原开发者/研究人员:

facebookresearch/XLM
   |---microsoft/MASS
        |---<this>

无监督的NMT

无监督神经机器翻译只使用单语数据来训练模型。 在大规模预训练期间,源语言和目标语言在一个模型中进行预训练,使用 相应的语言嵌入来区分语言。 在大规模微调过程中,采用反向平移的方法对无监督模型进行训练。 我们提供预先培训和微调的模型:

LanguagesPre-trained ModelFine-tuned ModelBPE codesVocabulary
EN - FRMODELMODELBPE codesVocabulary
EN - DEMODELMODELBPE codesVocabulary
En - ROMODELMODELBPE_codesVocabulary

我们还在准备更多语言对的更大模型,并将在将来发布。在

依赖性

目前我们基于XLM的代码库实现了无监督NMT的MASS。相关关系如下:

  • Python3
  • NumPy公司
  • Pythorch(版本0.4和1.0)
  • fastBPE(用于BPE代码)
  • 摩西(象征化)
  • Apex(用于fp16培训)

数据就绪

我们在XLM中使用相同的BPE代码和词汇表。这里我们以英语法语为例。在

^{pr2}$

培训前:

python train.py                                      \
--exp_name unsupMT_enfr                              \
--data_path ./data/processed/en-fr/                  \
--lgs 'en-fr'                                        \
--mass_steps 'en,fr'                                 \
--encoder_only false                                 \
--emb_dim 1024                                       \
--n_layers 6                                         \
--n_heads 8                                          \
--dropout 0.1                                        \
--attention_dropout 0.1                              \
--gelu_activation true                               \
--tokens_per_batch 3000                              \
--optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
--epoch_size 200000                                  \
--max_epoch 100                                      \
--eval_bleu true                                     \
--word_mass 0.5                                      \
--min_len 5                                          \

在预培训过程中,即使没有任何反向翻译,您也可以观察到模型可以获得一些初始BLEU分数:

epoch -> 4
valid_fr-en_mt_bleu -> 10.55
valid_en-fr_mt_bleu ->  7.81
test_fr-en_mt_bleu  -> 11.72
test_en-fr_mt_bleu  ->  8.80

分布式培训

在同一个节点上使用多个gpu,例如3个gpu

export NGPU=3; CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=$NGPU train.py [...args]

要跨many nodes使用multiple gpu,请使用Slurm请求多节点作业并启动上面的命令。 代码会自动检测SLURM\ux环境变量来分发培训。在

微调

在预培训之后,我们使用反向翻译来微调无监督机器翻译的预培训模型:

MODEL=mass_enfr_1024.pth

python train.py \
  --exp_name unsupMT_enfr                              \
  --data_path ./data/processed/en-fr/                  \
  --lgs 'en-fr'                                        \
  --bt_steps 'en-fr-en,fr-en-fr'                       \
  --encoder_only false                                 \
  --emb_dim 1024                                       \
  --n_layers 6                                         \
  --n_heads 8                                          \
  --dropout 0.1                                        \
  --attention_dropout 0.1                              \
  --gelu_activation true                               \
  --tokens_per_batch 2000                              \
  --batch_size 32	                                     \
  --bptt 256                                           \
  --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
  --epoch_size 200000                                  \
  --max_epoch 30                                       \
  --eval_bleu true                                     \
  --reload_model "$MODEL,$MODEL"                       \

我们还提供了一个在WMT16 en-ro双语数据集上使用大规模预训练模型的演示。我们提供预先培训和微调的模型:

^{tb2}$

通过以下命令下载数据集:

wget https://dl.fbaipublicfiles.com/XLM/codes_enro
wget https://dl.fbaipublicfiles.com/XLM/vocab_enro

./get-data-bilingual-enro-nmt.sh --src en --tgt ro --reload_codes codes_enro --reload_vocab vocab_enro

从上面的链接下载质量预训练模型后。并使用以下命令进行微调:

MODEL=mass_enro_1024.pth

python train.py \
	--exp_name unsupMT_enro                              \
	--data_path ./data/processed/en-ro                   \
	--lgs 'en-ro'                                        \
	--bt_steps 'en-ro-en,ro-en-ro'                       \
	--encoder_only false                                 \
	--mt_steps 'en-ro,ro-en'                             \
	--emb_dim 1024                                       \
	--n_layers 6                                         \
	--n_heads 8                                          \
	--dropout 0.1                                        \
	--attention_dropout 0.1                              \
	--gelu_activation true                               \
	--tokens_per_batch 2000                              \
	--batch_size 32                                      \
	--bptt 256                                           \
	--optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
	--epoch_size 200000                                  \
	--max_epoch 50                                       \
	--eval_bleu true                                     \
	--reload_model "$MODEL,$MODEL"

培训详情

MASS-base-uncased使用32xnvidia32gbv100gpu,在(Wikipekia+BookCorpus,16GB)上训练20个时代(float32),批量大小模拟为4096。在

其他问题

  1. Q: When I run this program in multi-gpus or multi-nodes, the program reports errors like ModuleNotFouldError: No module named 'mass'.
    A: This seems a bug in python multiprocessing/spawn.py, a direct solution is to move these files into each relative folder under fairseq. Do not forget to modify the import path in the code.

参考文献

如果你在工作中发现大量有用,你可以引用以下文章:

@inproceedings{song2019mass,
    title={MASS: Masked Sequence to Sequence Pre-training for Language Generation},
    author={Song, Kaitao and Tan, Xu and Qin, Tao and Lu, Jianfeng and Liu, Tie-Yan},
    booktitle={International Conference on Machine Learning},
    pages={5926--5936},
    year={2019}
}

相关工程

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
maven字段#getGenericType()抛出java。lang.TypeNotPresentException   用java绘制三角形的几何图形   java无法下载主题和发件人地址(rediff)   java如何使代码线程安全   java在尝试转换FileInputStream中的文件时,我遇到了一个FileNotFound异常   java Moxy和Jackson如何将Json映射到Pojo   在foreach循环中使用BufferedWriter生成新行的java问题   java为什么我的测试在单次执行中运行时间小于1秒,而在maven构建中运行时间大于20秒?   java如何显示下载附件的进度条   了解java rmi的良好实践   .net可以将Java portlet嵌入ASP。网页?   循环如何多次执行Java方法?   java如何确保用户输入在给定的有效范围内?   java单元测试定理   java如何在IntelliJ上运行外部构建项目?   JAVA:试图编写一个检查字符串是否为数字的方法。总是返回错误   javahadoop将特定键的所有map方法生成的所有值都发送到一个reduce方法,对吗?   在java中读取和使用文件