填充和注意掩码在GPT语言模型的批输入中无法正常工作

2024-06-16 08:58:33 发布

您现在位置:Python中文网/ 问答频道 /正文

以下代码没有批处理:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
context=torch.tensor([tokenizer.encode("This is")])
output, past = model(context)
token = torch.argmax(output[..., -1, :])
print(tokenizer.decode(token.item()))

output: ' a'

这很好用。现在,我将其扩展到批处理设置:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

context=[torch.tensor(tokenizer.encode("This is ")),torch.tensor(tokenizer.encode("Hello How are "))]
context=pad_sequence(context,batch_first=True)

mask=torch.tensor([[1,1,0],[1,1,1]])
output, past = model(context,attention_mask=mask)
token = torch.argmax(output[..., -1, :],dim=1)
tokenizer.decode(token)

output: '\n you'

这里\n是第一个上下文的下一个标记,you是批处理的第二个上下文的下一个标记。 但是第一个上下文的预期下一个标记是a,因为所有设置都是相同的。此外,如果您将第二个上下文减少为2个令牌,那么您将在该批处理设置中获得a。很明显,模型无法理解填充。 此外,注意力面罩也不起作用。因为 填充后,序列this is的下一个标记为0(零)。根据注意掩码([1,1,0]),应该避免这个零,并且只应该注意标记thisis。证明这种注意力掩蔽不起作用的证据有:

  • 使用注意掩码[1,1,1],这意味着即使在填充零上,也会得到相同的输出 这是\n

  • 使用字符串this is!。这里!在词汇表矩阵中有零索引。同样,您会得到相同的输出,即\n

只有在没有批量设置和注意遮罩的情况下,才能获得理想的输出(现在看来,这并不重要,因为它无论如何都没有效果)

然后我找到了this,它建议使用pad_令牌。所以我用了如下的方法:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from torch.nn.utils.rnn import pad_sequence  

tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token="<PAD>")
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

context=[torch.tensor(tokenizer.encode("This is <PAD> ")),torch.tensor(tokenizer.encode("Hello How are"))]
context=torch.stack(context)
print(context)
mask=torch.tensor([[1,1,0],[1,1,1]])

output, past = model(context,attention_mask=mask)
token = torch.argmax(output[..., -1, :],dim=1)
tokenizer.decode(token)

output: 'The you'

这里The是第一个上下文的下一个标记,you是批处理的第二个上下文的下一个标记。这也不起作用。因为第一个上下文不需要The

如何在gpt/gpt2模型的批次设置中使用可变长度序列


Tags: from标记importtokenoutputmodeliscontext