Estimating token probabilities/logits for a given set of tokens without scoring the entire sentence each time

Posted 2024-06-16 12:27:43


I have a sentence like: "I like sitting in my new chair and _____ about life"

And I have a specific set of tokens, e.g. ["watch", "run", "think", "apple", "light"]

I would like to calculate the probability of each of those tokens appearing as the next word in the incomplete sentence. Hopefully I would find, for instance, that the probability of "think" is higher than that of "apple".

I am working with pytorch-transformers (specifically GPT2LMHeadModel). One possible solution is to score the full sentence with each candidate token inserted, but when the number of tokens to evaluate is on the order of 100 or 1000, the computation time becomes far too long.

It must be possible to process the sentence only once and somehow use the hidden states to calculate the probabilities of the token set, but I don't know how to do it.
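
For the special case where a candidate is a single BPE token (and ignoring the words after the blank), I imagine something like the untested sketch below: run the left part through the model once and read each candidate's log-probability off the next-token distribution at the last position. My real candidates can span several tokens and the sentence continues after the blank, so I don't know how to generalize this.

import torch
import torch.nn.functional as F
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

prefix = "I like sitting in my new chair and"
prefix_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prefix))])

with torch.no_grad():
    lm_logits = model(prefix_ids)[0]            # (1, prefix_len, vocab_size)

# log-probabilities of the token that would follow the prefix
next_token_log_probs = F.log_softmax(lm_logits[0, -1], dim=-1)

for word in ["watch", "run", "think", "apple", "light"]:
    # the word sits after a space in the sentence, so encode it with a prefix space
    # (GPT-2's BPE distinguishes "watch" from " watch"; if your tokenizer version
    # lacks add_prefix_space, tokenizing " " + word is the alternative)
    word_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word, add_prefix_space=True))
    # only the first sub-token is scored here; multi-token candidates need more work
    print(word, next_token_log_probs[word_ids[0]].item())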

Any ideas? Thanks in advance.


Edit:

The actual code looks like the one below (the full-sentence probability is estimated every time). Running the score() method takes about 0.1 seconds per sentence, which turns into hours if I want to evaluate a few thousand words.

import torch
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def score(sentence):
    tokenize_input = tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
    loss = model(tensor_input, labels=tensor_input)
    return -loss[0].item()


candidates = ["watch", "run", "think", "apple", "light"]
sent_template = "I like sitting in my new chair and {} about life"
print({candidate: score(sent_template.format(candidate)) for candidate in candidates})

1 Answer

#1 · Posted 2024-06-16 12:27:43

Your example produced the following output and took about 48.5 seconds to complete with 282 candidates in my environment (I only did 3 runs):

{'watch': -5.406847953796387, 'run': -5.533411502838135, 'think': -4.525279521942139, 'apple': -6.158637046813965, 'light': -5.835141658782959}

As mentioned in the comments, I think you can save some of the computation with the past parameter and the fast tokenizer, as shown in the commented example below:

import torch

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from torch.nn import CrossEntropyLoss

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

###We calculate the hidden_states and the past of the common left part of the sentence
past = "I like sitting in my new chair and"
past_tokenize_input = tokenizer.tokenize(past)
past_tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(past_tokenize_input)])

past_last_hidden_state, past = model.transformer(past_tensor_input)

def score(sentence, past, past_last_hidden_state, past_tensor_input):
    tokenize_input = tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])

    ###the following code is slightly modified from https://github.com/huggingface/transformers/blob/09a2f40684f77e62d0fd8485fe9d2d610390453f/src/transformers/modeling_gpt2.py#L604
    ###now we calculate the right part of the sentence with the already calculated past
    transformer_outputs = model.transformer(
            tensor_input,
            past=past,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
        )
    ###and concatenate the output with the hidden_states of the left part of the sentence
    hidden_states = torch.cat((past_last_hidden_state, transformer_outputs[0]), dim=1)
    
    ###the following part is exactly the same as https://github.com/huggingface/transformers/blob/09a2f40684f77e62d0fd8485fe9d2d610390453f/src/transformers/modeling_gpt2.py#L604
    lm_logits = model.lm_head(hidden_states)

    labels_input = torch.cat((past_tensor_input, tensor_input), dim=1)

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels_input[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return -loss.item()

candidates = ["watch", "run", "think", "apple", "light"]

sent_template = " {} about life"

print({candidate: score(sent_template.format(candidate), past, past_last_hidden_state, past_tensor_input) for candidate in candidates})

Output:

{'watch': -5.406846046447754, 'run': -5.533413887023926, 'think': -4.525280952453613, 'apple': -6.158637046813965, 'light': -5.835141181945801}

The runtime here was 40.5 seconds with the same 282 candidates (again 3 runs). You can also see that I lose some precision.
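
As a follow-up: on a recent transformers release (4.x) the same trick should be expressible through the public past_key_values argument instead of calling model.transformer and model.lm_head by hand. This is only a sketch I have not benchmarked; the deepcopy guards against newer cache objects being modified in place when reused across candidates:

import torch
import torch.nn.functional as F
from copy import deepcopy

from transformers import GPT2TokenizerFast, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# run the shared left part once and keep its logits and key/value cache
prefix_ids = tokenizer("I like sitting in my new chair and", return_tensors="pt").input_ids
with torch.no_grad():
    prefix_out = model(prefix_ids, use_cache=True)

def score(continuation):
    cont_ids = tokenizer(continuation, return_tensors="pt").input_ids
    with torch.no_grad():
        # copy the cache so repeated calls all start from the bare prefix
        cont_out = model(cont_ids, past_key_values=deepcopy(prefix_out.past_key_values))
    # same shifted cross-entropy over the whole sentence as in the code above
    logits = torch.cat((prefix_out.logits, cont_out.logits), dim=1)
    labels = torch.cat((prefix_ids, cont_ids), dim=1)
    loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.size(-1)),
                           labels[:, 1:].reshape(-1))
    return -loss.item()

candidates = ["watch", "run", "think", "apple", "light"]
print({candidate: score(" {} about life".format(candidate)) for candidate in candidates})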

Many thanks to patrickvonplaten, who gave me a good explanation of the past implementation.
