Pytorch装载重物时的几个问题

2024-04-19 20:25:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我搜索了很多资源来解决这个问题,但仍然停留在这里。你知道吗

我遵循pytorch教程并使用

torch.save(the_model.state_dict(), PATH)

然后用

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

加载参数后,我打印了我的模型。结果是一个错误。你知道吗

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

我发现有些人面临同样的问题,但似乎可以忽略?!然后我尝试用```前进模式()'',出现了另一个错误。你知道吗

AttributeError: 'IncompatibleKeys' object has no attribute 'forward'

我知道这种保存方法(the_model.state_dict())就是保存“权重”。由于某些重要信息无法保存(dropout、batchnorm等),因此只能使用.eval()。所以我尝试model.eval(),它仍然有相同的错误。你知道吗

AttributeError: 'IncompatibleKeys' object has no attribute 'eval'

以下是一些相关代码:

  • 初始化模型:

    model = VAE(some constructor parameters)

  • 培训后:

    checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
    torch.save(model.state_dict(), checkpoint_path)
    
  • 初始化同一模型并将参数加载到模型中:

    model = VAE(some constructor parameters)
    checkpoint = torch.load("E24.pkl", map_location='cuda:0')
    model = model.load_state_dict(checkpoint)
    

我不会再训练这个模特了。我只想加载参数,然后检查性能。谢谢你的阅读。:)


Tags: thepath模型参数modelsave错误eval