Pytorch模型概述

2024-04-20 12:51:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用以下方式加载pytorch模型:

model = torch.load('/content/gdrive/model.pth.tar', map_location='cpu')

我想查看模型摘要。当我尝试时:

print(model)

我得到以下输出:

{'state_dict': {'model.conv1.weight': tensor([[[[ 2.0076e-02,  1.5264e-02, -1.2309e-02,  ..., -4.0222e-02,
           -4.0527e-02, -6.4458e-02],
          [ 6.3291e-03,  3.8393e-03,  1.2400e-02,  ..., -3.3926e-03,
           -2.1063e-02, -3.4743e-02],
          [ 1.9969e-02,  2.0064e-02,  1.4004e-02,  ...,  8.7359e-02,
            5.4801e-02,  4.8791e-02],
          ...,
          [ 2.5362e-02,  1.1433e-02, -7.6776e-02,  ..., -3.4798e-01,
           -2.7198e-01, -1.2066e-01],
          [ 8.0373e-02,  1.3095e-01,  1.4240e-01,  ..., -2.2933e-03,
           -1.0469e-01, -1.0922e-01],
          [-1.1147e-03,  7.4572e-02,  1.2814e-01,  ...,  1.6903e-01,
            1.0619e-01,  2.4744e-02]], 
      'model.layer4.1.bn2.running_var': tensor([0.0271, 0.0155, 0.0199, 0.0198, 0.0132, 0.0148, 0.0182, 0.0170, 0.0134,
.
.
.

这到底是什么意思

我还尝试使用:

from torchsummary import summary
summary(model, input_size=(3, 224, 224))

但它给了我以下错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-31-ca828d30bd38> in <module>()
      1 from torchsummary import summary
----> 2 summary(model, input_size=(3, 224, 224))

/usr/local/lib/python3.7/dist-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
     66 
     67     # register hook
---> 68     model.apply(register_hook)
     69 
     70     # make a forward pass

AttributeError: 'dict' object has no attribute 'apply'

请注意,该模型是我正在尝试加载的自定义模型。 如何在Pytorch中获取模型摘要


Tags: infrom模型importregisterinputsizemodel
1条回答
网友
1楼 · 发布于 2024-04-20 12:51:20

您加载了“*.pt”并没有将其提供给模型(这只是一个权重字典,取决于您保存的内容),这就是为什么您会得到以下输出:

{'state_dict': {'model.conv1.weight': tensor([[[[ 2.0076e-02,  1.5264e-02, -1.2309e-02,  ..., -4.0222e-02,
           -4.0527e-02, -6.4458e-02],
          [ 6.3291e-03,  3.8393e-03,  1.2400e-02,  ..., -3.3926e-03,
           -2.1063e-02, -3.4743e-02],
          [ 1.9969e-02,  2.0064e-02,  1.4004e-02,  ...,  8.7359e-02,
            5.4801e-02,  4.8791e-02],
          ...,
          [ 2.5362e-02,  1.1433e-02, -7.6776e-02,  ..., -3.4798e-01,
           -2.7198e-01, -1.2066e-01],
          [ 8.0373e-02,  1.3095e-01,  1.4240e-01,  ..., -2.2933e-03,
           -1.0469e-01, -1.0922e-01],
          [-1.1147e-03,  7.4572e-02,  1.2814e-01,  ...,  1.6903e-01,
            1.0619e-01,  2.4744e-02]], 
      'model.layer4.1.bn2.running_var': tensor([0.0271, 0.0155, 0.0199, 0.0198, 0.0132, 0.0148, 0.0182, 0.0170, 0.0134,
.
.
.

你应该做的是:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
print(model)

您可以参考pytorch doc

关于您的第二次尝试,同样的问题导致了问题,summary期望的是一个模型,而不是一个权重字典

相关问题 更多 >