预期4维权重[6,1,5,5]的4维输入,但得到了大小为[1,28,28]的3维输入

2024-04-19 06:41:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试建立一个足够复杂的神经网络来适应数据(我使用的是MNIST数据集)。我有一个小网络,我现在尝试建立一个新的网络,我偶然发现了这个问题。代码是:

class NN1(nn.Module):

    def __init__(self):
        super(NN1, self).__init__()
       
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

transform_list = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.0], std=[1.0,]) ] )

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_list)

mnist_trainset_small =  [ mnist_trainset[i] for i in range(0,4000) ] 

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_list)

nn1 = NN1()

tmp = nn1.forward( mnist_trainset[0][0])
tmp

我如何通过构建良好的网络来解决这个问题


Tags: 数据self网络truedeftransformnnlist
1条回答
网友
1楼 · 发布于 2024-04-19 06:41:57

您应该在^{}上使用^{}

mnist_train_dl = torch.utils.data.DataLoader(mnist_trainset, batch_size=16)

预定义的Pytorch模块使用批处理优先张量。在您的例子中,您的模型需要一个^{形状的张量

您不应该调用forward,而是直接调用您的模块nn1(x)

通常,您会在数据加载器中循环并推断/反向传播/更新每个批。比如:

for x, y in mnist_train_dl:
    out = nn1(x)
    # ...

但是,您可以通过访问第一批的第一个元素来推断一个元素来调试模型:

x, y = next(mnist_train_dl)
out = nn1(x[:1]) # target is y[:1]

通过[:1]而不是[0]进行索引,这样就不会挤压第一个轴

相关问题 更多 >