我搜索了很多资源来解决这个问题,但仍然停留在这里。你知道吗
我遵循pytorch教程并使用
torch.save(the_model.state_dict(), PATH)
然后用
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
加载参数后,我打印了我的模型。结果是一个错误。你知道吗
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
我发现有些人面临同样的问题,但似乎可以忽略?!然后我尝试用```前进模式()'',出现了另一个错误。你知道吗
AttributeError: 'IncompatibleKeys' object has no attribute 'forward'
我知道这种保存方法(the_model.state_dict()
)就是保存“权重”。由于某些重要信息无法保存(dropout、batchnorm等),因此只能使用.eval()
。所以我尝试model.eval()
,它仍然有相同的错误。你知道吗
AttributeError: 'IncompatibleKeys' object has no attribute 'eval'
以下是一些相关代码:
初始化模型:
model = VAE(some constructor parameters)
培训后:
checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
torch.save(model.state_dict(), checkpoint_path)
初始化同一模型并将参数加载到模型中:
model = VAE(some constructor parameters)
checkpoint = torch.load("E24.pkl", map_location='cuda:0')
model = model.load_state_dict(checkpoint)
我不会再训练这个模特了。我只想加载参数,然后检查性能。谢谢你的阅读。:)
目前没有回答
相关问题 更多 >
编程相关推荐