在Keras实施的变压器XL
keras-transformer-xl的Python项目详细描述
Keras变压器XL
非正式实施Transformer-XL。
安装
pip install keras-transformer-xl
用法
负载预应变重量
可以在the info directory找到几个配置文件。
importosfromkeras_transformer_xlimportload_trained_model_from_checkpointcheckpoint_path='foo/bar/sota/enwiki8'model=load_trained_model_from_checkpoint(config_path=os.path.join(checkpoint_path,'config.json'),checkpoint_path=os.path.join(checkpoint_path,'model.ckpt'))model.summary()
关于IO
生成的模型有两个输入,第二个输入是存储器的长度。
您可以使用MemorySequence
包装器进行训练和预测:
importkerasimportnumpyasnpfromkeras_transformer_xlimportMemorySequence,build_transformer_xlclassDummySequence(keras.utils.Sequence):def__init__(self):passdef__len__(self):return10def__getitem__(self,index):returnnp.ones((3,5*(index+1))),np.ones((3,5*(index+1),3))model=build_transformer_xl(units=4,embed_dim=4,hidden_dim=4,num_token=3,num_block=3,num_head=2,batch_size=3,memory_len=20,target_len=10,)seq=MemorySequence(model=model,sequence=DummySequence(),target_len=10,)model.predict(model,seq,verbose=True)
使用tf.keras
将TF_KERAS=1
添加到环境变量中以使用tensorflow.python.keras
。