PyTorch:保存权重和模型定义

2024-04-26 11:50:30 发布

您现在位置:Python中文网/ 问答频道 /正文

在原型设计过程中,我经常对PyTorch模型进行多次更改。例如,假设我正在试验的第一个模型是:

class Model(nn.Module):
    def __init__(self, **args):
        super().__init__()
        self.l1 = nn.Linear(128, 1)

然后我将添加另一层:

class Model(nn.Module):
    def __init__(self, **args):
        super().__init__()
        self.l1 = nn.Linear(128, 32)
        self.l2 = nn.Linear(32, 1)

或者加上一些卷积,等等

问题是,我执行的实验越多,我经常会变得杂乱无章,因为我还没有找到一种简单的方法来保存模型定义及其权重,以便加载以前的状态

我知道我能做到:

torch.save({'model': Model(), 'state': model.state_dict()}, path)
# or directly
torch.save(model, path)

但是,加载模型还需要模型类(这里,Model)存在于当前文件中

Keras中,您只需执行以下操作:

model = ...  # Get model (Sequential, Functional Model, or Model subclass)
model.save('path/to/location')

这节省了模型的架构/配置和权重等。这意味着您可以在没有定义体系结构的情况下加载模型:

model = keras.models.load_model('path/to/location')

关于Keras model saving

The SavedModel and HDF5 file contains:

  • the model's configuration (topology)
  • the model's weights
  • the model's optimizer's state (if any)

Thus models can be reinstantiated in the exact same state, without any of the code used for model definition or training.

这就是我想在PyTorch中实现的目标

PyTorch是否有类似的方法?对于这些情况,最佳做法是什么


Tags: orthepath模型selfmodelinitsave
2条回答

@D Hudson's answer是正确的选择。但是,为了将来的参考,我想添加以下对我有用的方法

让我们假设模型的forward方法是固定的,也就是说,只改变了底层架构,输入和;输出形状。在本例中,我们只对表示整个体系结构的Sequential属性感兴趣:

class Model(nn.Module):
    def __init__(self, **hparams):
        super(Model, self).__init__()
        
        # this attribute is the only thing we care about
        self.net = nn.Sequential(
            # experiment with different layers here ...
        )
        
    def forward(self, x):
        return self.net(x) # this is fixed!

然后,我们可以保存模型架构(基本上只是net属性)及其权重:

m = Model()
# train/test/valid ...
T.save({'net': m.net, 'weights': m.state_dict()}, './version1.pth')

最后,按照如下方式执行加载:

m = Model()
checkpoint = T.load('./version1.pth')
m.net = checkpoint['net']
m.load_state_dict(checkpoint['weights'])

由于Pytorch在模型中提供了巨大的灵活性,因此在单个文件中保存体系结构和权重将是一个挑战。Keras模型通常仅通过堆叠Keras组件来构建,但pytorch模型由库使用者以自己的方式编排,因此可以包含任何类型的逻辑

我认为你有三个选择:

  1. 为你的实验提出一个有组织的模式,这样就不太可能丢失模型定义。您可以选择一些简单的方法,例如通过仅定义每个模型的模式命名的文件。我会推荐这种方法,因为这种级别的组织可能会从其他方面受益,并且开销最小

  2. 尝试将代码与pickle文件一起保存。虽然有可能,但我认为这会让你陷入一个有很多潜在问题的困境

  3. 使用不同的标准化方法保存模型,例如^{}。如果您不想选择选项1,我建议您选择这条路线。Onnx确实允许您保存pytorch模型的架构及其权重,但也有一些缺点。例如,它只支持某些操作,因此完全自定义的forward方法或使用非矩阵操作可能不起作用

相关问题 更多 >