保存和加载PyTorch模型的问题

-1 投票
0 回答
26 浏览
提问于 2025-04-11 22:37

我正在尝试设计一个神经网络,用来通过PyTorch模型对信号进行分类。在训练完模型后,我会保存这个模型,然后在推理阶段加载它。

我训练完模型后,用下面的代码保存模型:

torch.save(model.state_dict(), ".pth")

然后我想在推理阶段加载这个模型。我使用了以下代码:

model_test = spectrogram_model(X_Test)
model_test.load_state_dict(torch.load(".pth"))

但是我遇到了一个错误,错误信息如下:

AttributeError: 'Tensor' object has no attribute 'load_state_dict'

spectrogram_model的内容如下:

class Spectrogram(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(0.04)
        self.hidden1 = nn.Linear(16, 12)
        self.act1 = nn.ReLU()
        self.hidden2 = nn.Linear(12, 8)
        self.act2 = nn.ReLU()
        self.hidden3 = nn.Linear(8, 4)

    def forward(self, x):
        x = self.dropout(x)
        x = self.act1(self.hidden1(x))
        x = self.act2(self.hidden2(x))
        x = self.hidden3(x)
        return x

你有什么解决方案吗?

0 个回答

暂无回答

撰写回答