How to find the CLS token using a pretrained model

Published 2024-04-26 10:19:55

After passing my input_ids and token_type_ids into my model, I get back a 1-D array instead of the expected tuple of a 3-D and a 2-D tensor.

import torch
from transformers import BertTokenizer, BertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
cls = []    
for text in sliced_text:
    input_ids = tokenizer.encode(text)
    token_type_ids = [0 if i <= input_ids.index(102) else 1
              for i in range(len(input_ids))]
    with torch.no_grad():
        # Give token_type_ids the same (batch, seq_len) shape as input_ids
        cls_token = model.bert(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
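
The segment-id logic and the position of the [CLS] token can be checked without loading the model. The following is a minimal sketch in plain Python, not the asker's actual pipeline. It assumes the bert-*-uncased vocabulary, where [CLS] has id 101 and [SEP] has id 102 (102 is the value used in the code above). The helper names `build_token_type_ids` and `cls_vectors`, the other token ids, and the toy hidden-state values are hypothetical and only there for illustration. A nested list stands in for the (batch, seq_len, hidden) tensor.

```python
CLS_ID = 101  # assumed [CLS] id in the bert-*-uncased vocabularies
SEP_ID = 102  # [SEP] id, as used in the question's code


def build_token_type_ids(input_ids, sep_id=SEP_ID):
    """Return 0 for every token up to and including the first [SEP], 1 afterwards."""
    first_sep = input_ids.index(sep_id)  # look the position up once, not per token
    return [0 if i <= first_sep else 1 for i in range(len(input_ids))]


def cls_vectors(last_hidden_state):
    """Take the position-0 vector of every sequence.

    `last_hidden_state` is a nested list shaped (batch, seq_len, hidden).
    """
    return [sequence[0] for sequence in last_hidden_state]


# Placeholder token ids: a sentence pair and a single sentence.
pair_ids = [CLS_ID, 2054, 2003, SEP_ID, 2023, 2003, SEP_ID]
single_ids = [CLS_ID, 2054, 2003, SEP_ID]

pair_types = build_token_type_ids(pair_ids)      # second segment is marked 1
single_types = build_token_type_ids(single_ids)  # only one segment, all 0

# Toy hidden states: batch of 1, 3 tokens, hidden size 2.
toy_hidden = [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]]
toy_cls = cls_vectors(toy_hidden)

print(pair_types)
print(single_types)
print(toy_cls)
```

Because the tokenizer puts [CLS] first, its vector is the one at sequence position 0. On a real PyTorch tensor of shape (batch, seq_len, hidden), the equivalent of `cls_vectors` is the slice `last_hidden_state[:, 0, :]`. Note also that encoding a single text gives only one [SEP], at the end, so every token type id comes out as 0.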