在pytorch中创建一个新模型,并为权重设置自定义初始值

2024-03-29 07:22:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我是pytorch的新手,我想了解如何设置网络第一个隐藏层的初始权重。我解释得更好一点:我的网络是一个非常简单的一层MLP,有784个输入值和10个输出值

 class Classifier(nn.Module):
        def __init__(self):
          super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        # Dropout module with 0.2 drop probability
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        # make sure input tensor is flattened
        # x = x.view(x.shape[0], -1)

        # Now with dropout
        x = self.dropout(F.relu(self.fc1(x)))

        # output so no dropout here
        x = F.log_softmax(self.fc2(x), dim=1)

        return x 

现在,我有一个形状的numpy矩阵(128784),它包含了我想要的fc1中的权重值。如何使用矩阵中包含的值初始化第一层的权重?你知道吗

在网上搜索其他答案,我发现我必须为权重定义init函数

def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv2d') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

但我不懂密码


Tags: self网络datainitdefwithnnclass
1条回答
网友
1楼 · 发布于 2024-03-29 07:22:53

您可以简单地使用torch.nn.Parameter()为您的网络层分配一个自定义权重。你知道吗

就像你的情况一样-

model.fc1.weight = torch.nn.Parameter(custom_weight)

torch.nn.Parameter:一种张量,被认为是模参数。你知道吗

例如

# Classifier model
model = Classifier()

# your custom weight, here taking randam
custom_weight = torch.rand(model.fc1.weight.shape)
custom_weight.shape
torch.Size([128, 784])

# before assign custom weight
print(model.fc1.weight)
Parameter containing:
tensor([[ 1.6920e-02,  4.6515e-03, -1.0214e-02,  ..., -7.6517e-03,
          2.3892e-02, -8.8965e-03],
        ...,
        [-2.3137e-02,  5.8483e-03,  4.4392e-03,  ..., -1.6159e-02,
          7.9369e-03, -7.7326e-03]])

# assign custom weight to first layer
model.fc1.weight = torch.nn.Parameter(custom_weight)

# after assign custom weight
model.fc1.weight
Parameter containing:
tensor([[ 0.1724,  0.7513,  0.8454,  ...,  0.8780,  0.5330,  0.5847],
        [ 0.8500,  0.7687,  0.3371,  ...,  0.7464,  0.1503,  0.7720],
        [ 0.8514,  0.6530,  0.6261,  ...,  0.7867,  0.9312,  0.3890],
        ...,
        [ 0.5426,  0.7655,  0.1191,  ...,  0.4343,  0.2500,  0.6207],
        [ 0.2310,  0.4260,  0.4138,  ...,  0.1168,  0.5946,  0.2505],
        [ 0.4220,  0.5500,  0.6282,  ...,  0.5921,  0.7953,  0.9997]])

相关问题 更多 >