如何恢复BERT/XLNet嵌入?

2024-05-15 15:25:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我最近一直在尝试堆叠语言模型,并注意到一些有趣的事情:BERT和XLNet的输出嵌入与输入嵌入不同。例如,以下代码段:

bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tok = transformers.BertTokenizer.from_pretrained("bert-base-cased")

sent = torch.tensor(tok.encode("I went to the store the other day, it was very rewarding."))
enc = bert.get_input_embeddings()(sent)
dec = bert.get_output_embeddings()(enc)

print(tok.decode(dec.softmax(-1).argmax(-1)))

给我这个:

,,,,,,,,,,,,,,,,,

我本来希望返回(格式化的)输入序列,因为我觉得输入和输出令牌嵌入是绑定的

有趣的是,大多数其他模型都没有表现出这种行为。例如,如果在GPT2、Albert或Roberta上运行相同的代码段,它将输出输入序列

这是虫子吗?或者是预期的BERT/XLNet


Tags: thefrom模型baseget代码段sentbert
1条回答
网友
1楼 · 发布于 2024-05-15 15:25:36

不确定是否为时已晚,但我已经对您的代码进行了一些实验,它可以恢复。:)

bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tok = transformers.BertTokenizer.from_pretrained("bert-base-cased")

sent = torch.tensor(tok.encode("I went to the store the other day, it was very rewarding."))
print("Initial sentence:", sent)
enc = bert.get_input_embeddings()(sent)
dec = bert.get_output_embeddings()(enc)

print("Decoded sentence:", tok.decode(dec.softmax(0).argmax(1)))

为此,您将获得以下输出:

Initial sentence: tensor([  101,   146,  1355,  1106,  1103,  2984,  1103,  1168,  1285,   117,
         1122,  1108,  1304, 10703,  1158,   119,   102])  
Decoded sentence: [CLS] I went to the store the other day, it was very rewarding. [SEP]

相关问题 更多 >