如何在AllenNLP中加载微调的sciBERT模型?

2024-05-12 18:52:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我已经在SCIE数据集上对SciBERT模型进行了微调。存储库使用AllenNLP对模型进行微调。培训内容如下:

python -m allennlp.run train $CONFIG_FILE  --include-package scibert -s "$@" 

经过成功的培训后,我有一个model.tar.gz文件作为输出,其中包含weights.th、config.json和词汇文件夹。我已尝试将其加载到allenlp预测器中:

from allennlp.predictors.predictor import Predictor
predictor = Predictor.from_path("model.tar.gz")

但我得到了以下错误:

ConfigurationError: bert-pretrained not in acceptable choices for dataset_reader.token_indexers.bert.type: ['single_id', 'characters', 'elmo_characters', 'spacy', 'pretrained_transformer', 'pretrained_transformer_mismatched']. You should either use the --include-package flag to make sure the correct module is loaded, or use a fully qualified class name in your config file like {"model": "my_module.models.MyModel"} to have it imported automatically.

我从未与allenNLP合作过,所以我对该做什么感到困惑

作为参考,这是描述令牌索引器的配置部分

"token_indexers": {
            "bert": {
                "type": "bert-pretrained",
                "do_lowercase": "false",
                "pretrained_model": "/home/tomaz/neo4j/scibert/model/vocab.txt",
                "use_starting_offsets": true
            }
        }

我正在使用allenlp版本

姓名:allennlp 版本:1.2.1

编辑:

我想我已经取得了很大的进步,我必须使用用于训练模型的相同版本,我可以像这样导入模块:

from allennlp.predictors.predictor import Predictor
from scibert.models.bert_crf_tagger import *
from scibert.models.bert_text_classifier import *
from scibert.models.dummy_seq2seq import *
from scibert.dataset_readers.classification_dataset_reader import *

predictor = Predictor.from_path("scibert_ner/model.tar.gz")
dataset_reader="classification_dataset_reader")
predictor.predict(
  sentence="Did Uriah honestly think he could beat The Legend of Zelda in under three hours?"
)

现在我得到一个错误:

No default predictor for model type bert_crf_tagger.\nPlease specify a predictor explicitly

我知道我可以使用predictor_name来显式地指定一个预测器,但是我一点儿也不知道选择哪个名称会起作用


Tags: infromimportmodelmodelstarpredictordataset
1条回答
网友
1楼 · 发布于 2024-05-12 18:52:36

我见过很多人有这个问题。查看存储库代码后,我发现这是运行预测的最简单方法:

python -m allennlp.run predict /path/to/saved_model/model.tar.gz /path/to/test.txt\
   include-package scibert  use-dataset-reader\
   output-file /path/to/where/you/want/predict.txt\
   predictor  sentence-tagger  batch-size 16

我补充了什么?预测器sentence-tagger。浏览存储库后,您会发现注册的预测值是sentence-tagger。尽管标记器的DEFAUL_DICT包含sentence_tagger。有很多困惑,对吗?告诉我

这个答案还可以避免你写predictor

相关问题 更多 >