当网络的输出类型为list时,如何在pytorch中生成网络的架构图?

2024-04-20 11:15:33 发布

您现在位置:Python中文网/ 问答频道 /正文

当输出为列表类型时,我想知道如何使用tochviz生成网络架构? 演示代码如下所示:

 import torch
 import torch.nn as nn
 class ConvNet(nn.Module):
     def __init__(self):
         super(ConvNet, self).__init__()
         self.conv1 = nn.Sequential(
             nn.Conv2d(1, 16, 3, 1, 1),
             nn.ReLU(),
             nn.AvgPool2d(2, 2)
         )
         self.conv2 = nn.Sequential(
             nn.Conv2d(16, 32, 3, 1, 1),
             nn.ReLU(),
             nn.MaxPool2d(2, 2)
         )
         self.fc = nn.Sequential(
             nn.Linear(32 * 7 * 7, 128),
             nn.ReLU(),
             nn.Linear(128, 64),
             nn.ReLU()
         )
         self.out = nn.Linear(64, 10)
     def forward(self, x):
         x = self.conv1(x)
         x = self.conv2(x)
         x = x.view(x.size(0), -1)
         x = self.fc(x)
         output = []
         output.append(x)
         output.append(self.out(x))
         return output
 MyConvNet = ConvNet()

我使用torchviz来查看这个网络的架构,就像

 from torchviz import make_dot
 x = torch.randn(1, 1, 28, 28).requires_grad_(True)
 y = MyConvNet(x)   
 MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
 MyConvNetVis.format = "png"
 MyConvNetVis.directory = "data"
 MyConvNetVis.view()

然后,我被这个问题阻碍了

AttributeError                            Traceback (most recent call last)
<ipython-input-23-c8e3cd3a8b4e> in <module>
      2 x = torch.randn(1, 1, 28, 28).requires_grad_(True)
      3 y = MyConvNet(x)
----> 4 MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
      5 MyConvNetVis.format = "png"
      6 MyConvNetVis.directory = "data"

~/anaconda3/envs/torch1.3/lib/python3.6/site-packages/torchviz/dot.py in make_dot(var, params)
     35         return '(' + (', ').join(['%d' % v for v in size]) + ')'
     36 
---> 37     output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
     38 
     39     def add_nodes(var):

AttributeError: 'list' object has no attribute 'grad_fn'

如有任何建议,将不胜感激


Tags: inimportselfoutputmakevardefnn
1条回答
网友
1楼 · 发布于 2024-04-20 11:15:33

错误表明torchviz正试图使用grad_fn在网络中导航,以便计算自己的图形。然而,元组不是张量,并且不具有gra_fn属性。 我不太确定在使用torchviz时是否可以有多个输出(即一个元组作为输出)。作为一种解决方法,如果您只想可视化网络,可以使用^{}将两个张量串联起来,替换元组

def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    out = self.out(x)
    output = torch.cat([x, out], dim=1)
    return output

结果是:

enter image description here

请注意,最后一个节点是一个CatBackward,有两个传入分支,一个来自AddmmBackwardout),另一个来自ReluBackward0x)。最后一个节点是虚构的,不存在于实际模型中,因此您可以手动将其从图形中删除

相关问题 更多 >