Python: OSError: Can't load config for BERT

Posted 2024-06-07 04:46:32


I am trying to train a bert-base-multilingual-uncased model for a task. My dataset contains all the required files, including BERT's config.json, but when I run the model it raises an error.

Config

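from transformers import BertTokenizer

class config:
    DEVICE = "cuda:2"
    MAX_LEN = 256
    TRAIN_BATCH_SIZE = 8
    VALID_BATCH_SIZE = 4
    EPOCHS = 1
    # Local directory that holds config.json, pytorch_model.bin and vocab.txt
    BERT_PATH = "workspace/data/jigsaw-multilingual/input/bert-base-multilingual-uncased"
    MODEL_PATH = "workspace/data/jigsaw-multilingual/model.bin"
    # Note: the tokenizer is loaded by Hub model name, not from BERT_PATH
    TOKENIZER = BertTokenizer.from_pretrained('bert-base-multilingual-uncased', do_lower_case=True)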
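BERT_PATH starts with `workspace/` rather than `/workspace/`, so Python treats it as a relative path and resolves it against the current working directory of the process (for a Jupyter kernel, usually the directory the notebook server was started from). Whether this is what goes wrong here is an assumption; the post does not say where the notebook runs. A minimal stdlib sketch that shows what the path actually points to; `resolve_against` is a hypothetical helper:

```python
import os
from pathlib import Path

# Value copied from config.BERT_PATH in the question
BERT_PATH = "workspace/data/jigsaw-multilingual/input/bert-base-multilingual-uncased"


def resolve_against(path: str, cwd: str) -> str:
    """Return the absolute path that `path` refers to when the process's
    working directory is `cwd`. Absolute paths are returned unchanged."""
    p = Path(path)
    if p.is_absolute():
        return str(p)
    return str(Path(cwd) / p)


if __name__ == "__main__":
    cwd = os.getcwd()
    target = resolve_against(BERT_PATH, cwd)
    print("cwd:        ", cwd)
    print("resolves to:", target)
    print("is a dir:   ", os.path.isdir(target))
```

If the dataset is in fact mounted at the filesystem root, the absolute form `/workspace/...` would be the value to try instead; that location is a guess, not something the post confirms.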

Model

import torch
import torch.nn as nn
import transformers

class BERTBaseUncased(nn.Module):
    def __init__(self):
        super(BERTBaseUncased, self).__init__()
        self.bert = transformers.BertModel.from_pretrained(config.BERT_PATH)
        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(768 * 2, 1) # *2 since we have 2 pooling layers

    def forward(self, ids, mask, token_type_ids):
        # [0] is the last hidden state (batch, seq_len, 768); indexing works for
        # tuple outputs (transformers 3.x) and ModelOutput objects (4.x) alike
        o1 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)[0]
        
        mean_pooling = torch.mean(o1, 1)
        max_pooling, _ = torch.max(o1, 1)
        cat = torch.cat((mean_pooling, max_pooling), 1)
        
        bo = self.bert_drop(cat)
        output = self.out(bo)
        return output
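The forward pass pools the last hidden state over the token dimension twice (mean and max) and concatenates the two vectors, which is why the output layer takes `768 * 2` inputs. A framework-free sketch of the same arithmetic on plain lists, for a single sequence with a toy hidden size of 3 instead of 768 (the numbers are made up for illustration; `mean_max_pool` is a hypothetical name):

```python
from typing import List

Matrix = List[List[float]]  # one row per token, one column per hidden unit


def mean_max_pool(hidden: Matrix) -> List[float]:
    """Mean-pool and max-pool over tokens, then concatenate.

    Mirrors torch.mean(o1, 1), torch.max(o1, 1) and torch.cat(..., 1)
    for one sequence (no batch dimension)."""
    n_tokens = len(hidden)
    columns = list(zip(*hidden))  # transpose: one tuple per hidden unit
    mean_pooling = [sum(col) / n_tokens for col in columns]
    max_pooling = [max(col) for col in columns]
    return mean_pooling + max_pooling  # length is 2 * hidden_size


if __name__ == "__main__":
    # 2 tokens, toy hidden size 3
    o1 = [[1.0, 4.0, -2.0],
          [3.0, 0.0, 2.0]]
    pooled = mean_max_pool(o1)
    print(pooled)       # [2.0, 2.0, 0.0, 3.0, 4.0, 2.0]
    print(len(pooled))  # 6, i.e. 2 * 3
```

Like the original `forward`, this does not use the attention mask, so padding positions are included in both the mean and the max.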

Error

---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/transformers/configuration_utils.py in get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    241             if resolved_config_file is None:
--> 242                 raise EnvironmentError
    243             config_dict = cls._dict_from_json_file(resolved_config_file)

OSError: 

During handling of the above exception, another exception occurred:

OSError                                   Traceback (most recent call last)
<ipython-input-64-9f2999c88020> in <module>
     79 
     80 if __name__ == "__main__":
---> 81     run()

<ipython-input-64-9f2999c88020> in run()
     38 
     39     device = torch.device(config.DEVICE)
---> 40     model = BERTBaseUncased()
     41     model.to(device)
     42 

<ipython-input-60-8e1508eac60a> in __init__(self)
      2     def __init__(self):
      3         super(BERTBaseUncased, self).__init__()
----> 4         self.bert = transformers.BertModel.from_pretrained(config.BERT_PATH)
      5         self.bert_drop = nn.Dropout(0.3)
      6         self.out = nn.Linear(768 * 2, 1) # *2 since we have 2 pooling layers

/opt/conda/lib/python3.6/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    601                 proxies=proxies,
    602                 local_files_only=local_files_only,
--> 603                 **kwargs,
    604             )
    605         else:

/opt/conda/lib/python3.6/site-packages/transformers/configuration_utils.py in from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
    198 
    199         """
--> 200         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
    201         return cls.from_dict(config_dict, **kwargs)
    202 

/opt/conda/lib/python3.6/site-packages/transformers/configuration_utils.py in get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    249                 f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
    250             )
--> 251             raise EnvironmentError(msg)
    252 
    253         except json.JSONDecodeError:

OSError: Can't load config for 'workspace/data/jigsaw-multilingual/input/bert-base-multilingual-uncased'. Make sure that:

- 'workspace/data/jigsaw-multilingual/input/bert-base-multilingual-uncased' is a correct model identifier listed on 'https://huggingface.co/models'

- or 'workspace/data/jigsaw-multilingual/input/bert-base-multilingual-uncased' is the correct path to a directory containing a config.json file
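The message names two accepted forms for the argument: a model identifier on the Hub, or a path to a directory containing `config.json`. For a local path, the second condition can be checked before the model is built. `check_model_dir` below is a hypothetical helper written for this sketch, not part of transformers:

```python
import os
import tempfile
from typing import List

CONFIG_NAME = "config.json"  # file name the error message asks for


def check_model_dir(path: str) -> List[str]:
    """Return a list of problems with `path` as a local model directory.

    An empty list means the directory exists and contains config.json."""
    problems = []
    if not os.path.exists(path):
        problems.append(f"{path!r} does not exist (cwd is {os.getcwd()!r})")
    elif not os.path.isdir(path):
        problems.append(f"{path!r} is a file, not a directory")
    elif not os.path.isfile(os.path.join(path, CONFIG_NAME)):
        problems.append(f"{path!r} has no {CONFIG_NAME}")
    return problems


if __name__ == "__main__":
    # Demo on a throwaway directory: first without, then with config.json
    with tempfile.TemporaryDirectory() as d:
        print(check_model_dir(d))  # one problem: no config.json
        open(os.path.join(d, CONFIG_NAME), "w").close()
        print(check_model_dir(d))  # []
```

Calling `check_model_dir(config.BERT_PATH)` just before `BERTBaseUncased()` and printing the result shows which of the two conditions fails for the notebook's actual working directory.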

These are the files present in my BERT dataset directory:
- config.json
- pytorch_model.bin
- vocab.txt
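Since the three files do exist somewhere, another quick check is to search for them and compare the directory that turns up with BERT_PATH. An extra nesting level left over from unzipping the dataset, for example, would make the two differ; that scenario is hypothetical, not something the post reports. A stdlib sketch; `find_model_dirs` is a hypothetical helper:

```python
import tempfile
from pathlib import Path
from typing import List

# Files the question lists as present in the BERT directory
REQUIRED = ("config.json", "pytorch_model.bin", "vocab.txt")


def find_model_dirs(root: str) -> List[str]:
    """Return every directory under `root` that contains all REQUIRED files,
    as sorted paths relative to `root`."""
    base = Path(root)
    found = []
    for cfg in base.rglob("config.json"):
        d = cfg.parent
        if all((d / name).is_file() for name in REQUIRED):
            found.append(str(d.relative_to(base)))
    return sorted(found)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        # Simulate a doubly nested model folder, e.g. from unzipping an archive
        model_dir = Path(root, "input", "bert-base-multilingual-uncased",
                         "bert-base-multilingual-uncased")
        model_dir.mkdir(parents=True)
        for name in REQUIRED:
            (model_dir / name).touch()
        print(find_model_dirs(root))
```

On a real system, pass the dataset's top-level folder as `root` rather than `/`, since `rglob` walks the whole tree below it.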

How can I fix this?

