保存和加载PyTorch模型的问题
我正在尝试设计一个神经网络,用来通过PyTorch模型对信号进行分类。在训练完模型后,我会保存这个模型,然后在推理阶段加载它。
我训练完模型后,用下面的代码保存模型:
torch.save(model.state_dict(), ".pth")
然后我想在推理阶段加载这个模型。我使用了以下代码:
model_test = spectrogram_model(X_Test)
model_test.load_state_dict(torch.load(".pth"))
但是我遇到了一个错误,错误信息如下:
AttributeError: 'Tensor' object has no attribute 'load_state_dict'
spectrogram_model的内容如下:
class Spectrogram(nn.Module):
def __init__(self):
super().__init__()
self.dropout = nn.Dropout(0.04)
self.hidden1 = nn.Linear(16, 12)
self.act1 = nn.ReLU()
self.hidden2 = nn.Linear(12, 8)
self.act2 = nn.ReLU()
self.hidden3 = nn.Linear(8, 4)
def forward(self, x):
x = self.dropout(x)
x = self.act1(self.hidden1(x))
x = self.act2(self.hidden2(x))
x = self.hidden3(x)
return x
你有什么解决方案吗?
0 个回答
暂无回答