如何使PyTorch模块的子模块不是模块的属性

2024-03-28 20:11:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我想要一个Module的PyTorch子类,它将子模块保存在一个列表中(因为根据构造函数的参数,子模块的数量可能是可变的)。我按以下方式设置此列表:

self.hidden_layers = [torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)]

根据thisthis问题,当__setattr__对象被分配给self属性时,子模块仅由__setattr__注册。由于hidden_layers未分配类型为Module的对象,因此列表中的子模块未注册为子模块,因此self.parameters()不会迭代子模块的参数

我想我可以为列表中的每个元素显式地调用__subattr__,但那会非常难看。有没有更正确的方法来注册不是Module的直接属性的子模块


Tags: 模块对象self列表参数数量属性layers
2条回答

正如所回答的nn.ModuleList就是你想要的

您还可以使用nn.Sequential。您可以创建一个层列表,然后通过nn.Sequential将它们组合起来,这将只是一个包装器,将所有层组合成一个基本层/模块。这样做的好处是,您只需要一个调用就可以将其转发到所有层,如果您有一个动态的模块计数,那么这很好,这样您就不必自己编写循环

一个例子是pytorch ResNet代码:https://github.com/pytorch/vision/blob/497744b9d510ff2df756f479ee5a19fce0d579b6/torchvision/models/resnet.py#L177

使用nn.ModuleList

self.hidden_layers = nn.ModuleList([torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)])

相关问题 更多 >