一般来说, 在构造函数中,我们声明要使用的所有层。 在forward函数中,我们定义了从输入到输出模型的运行方式
我的问题是,如果直接在forward()
函数中调用这些预定义/内置的nn.Modules,会怎么样?这种Keras function API风格的Pytorch是否合法?若否,原因为何
更新:以这种方式构建的TestModel确实成功运行,没有出现警报。但与传统的训练方式相比,训练损耗下降缓慢
import torch.nn as nn
from cnn import CNN
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.num_embeddings = 2020
self.embedding_dim = 51
def forward(self, input):
x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
# CNN is a customized class and nn.Module subclassed
# we will ignore the arguments for its instantiation
x = CNN(...)(x)
x = nn.ReLu()(x)
x = nn.Dropout(p=0.2)(x)
return output = x
你想做的事是可以做的,但不应该做,因为在大多数情况下这是完全不必要的。在我看来,这并不更具可读性,而且肯定与PyTorch的方式背道而驰
在
forward
中,层每次都会重新初始化,并且它们不会在网络中注册要正确执行此操作,您可以使用
Module
的add_module()
函数来防止重新分配(下面的方法dynamic
):你可以用不同的方式构建它,但这正是它背后的理念
真正的用例可能是当层的创建在某种程度上依赖于传递给
forward
的数据时,但这可能表明程序设计中存在一些缺陷您需要考虑可训练参数的范围
例如,如果您在模型的
forward
函数中定义一个conv层,那么这个“层”及其可训练参数的范围是该函数的局部,并且在每次调用forward
方法后都将被丢弃。您无法更新和训练每次forward
通过后不断丢弃的权重。但是,当conv层是
model
的成员时,它的作用域超出了forward
方法,并且只要model
对象存在,可训练参数就会持续存在。这样可以更新和训练模型及其权重相关问题 更多 >
编程相关推荐