我想要一个Module
的PyTorch子类,它将子模块保存在一个列表中(因为根据构造函数的参数,子模块的数量可能是可变的)。我按以下方式设置此列表:
self.hidden_layers = [torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)]
根据this和this问题,当__setattr__
对象被分配给self
属性时,子模块仅由__setattr__
注册。由于hidden_layers
未分配类型为Module
的对象,因此列表中的子模块未注册为子模块,因此self.parameters()
不会迭代子模块的参数
我想我可以为列表中的每个元素显式地调用__subattr__
,但那会非常难看。有没有更正确的方法来注册不是Module
的直接属性的子模块
正如所回答的
nn.ModuleList
就是你想要的您还可以使用
nn.Sequential
。您可以创建一个层列表,然后通过nn.Sequential
将它们组合起来,这将只是一个包装器,将所有层组合成一个基本层/模块。这样做的好处是,您只需要一个调用就可以将其转发到所有层,如果您有一个动态的模块计数,那么这很好,这样您就不必自己编写循环一个例子是pytorch ResNet代码:https://github.com/pytorch/vision/blob/497744b9d510ff2df756f479ee5a19fce0d579b6/torchvision/models/resnet.py#L177
使用
nn.ModuleList
相关问题 更多 >
编程相关推荐