PyTorch 1.2.0 CrossEntropyLoss error: only batches of spatial targets supported (3D tensors) but got targets of dimension: 2


I've just dusted off an old template for my basic MLP and I'm a bit rusty. I'm using CrossEntropyLoss on MNIST (images converted to tensors).

I've tried leaving the batch size out of the picture, but the loss still insists that the target tensor be 3-dimensional. If someone could explain why this is now required, it would be much appreciated.

P.S. This seems related to the post linked here, but it got no reply.

My code:

import random
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

train_dataset = datasets.MNIST('../', download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST('../', train=False, download=True, transform=transforms.ToTensor())

indices = list(range(len(train_dataset.data)))
indices = list(range(len(test_dataset.data)))

random.shuffle(indices)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=batchSize,
                                          num_workers=1)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batchSize,
                                          num_workers=1)


# Surrogate loss used for training
loss_fn = nn.CrossEntropyLoss()
test_loss_fn = nn.CrossEntropyLoss(reduction='sum')


optimizer = optim.Adam(model.parameters(), lr=lr)
#optimizer = optim.SGD(model.parameters(), lr=lr ,weight_decay=weight_decay)


print('Training beginning...')
start_time = time.time()

for epoch in range(1, nbr_epochs+1):
    print('Epoch ', epoch, ':')
    train(model, train_loader, optimizer, epoch,loss_fn)
    loss, acc = test(model, test_loader)

    test_accuracy.append(acc)
    train_loss.append(loss)

def train(model,train_loader, optimizer, epoch, loss_fn):
    model.train()

    for batch_idx, (inputs, target) in enumerate(train_loader):

        #inputs = inputs.view(batchSize, 1,100,100)
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target.unsqueeze(1))

        # Backprop
        loss.backward()
        optimizer.step()



The error is:


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-c6a8cb2b8f18> in <module>
     22 for epoch in range(1, nbr_epochs+1):
     23     print('Epoch ', epoch, ':')
---> 24     train(model, train_loader, optimizer, epoch,loss_fn)
     25     loss, acc = test(model, test_loader)
     26 

<ipython-input-29-d78d7fdaeb4d> in train(model, train_loader, optimizer, epoch, loss_fn)
     10         optimizer.zero_grad()
     11         output = model(inputs)
---> 12         loss = loss_fn(output, target.unsqueeze(1))
     13 
     14         # Backprop

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/.local/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    914     def forward(self, input, target):
    915         return F.cross_entropy(input, target, weight=self.weight,
--> 916                                ignore_index=self.ignore_index, reduction=self.reduction)
    917 
    918 

~/.local/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   1993     if size_average is not None or reduce is not None:
   1994         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1995     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   1996 
   1997 

~/.local/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1824         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1825     elif dim == 4:
-> 1826         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1827     else:
   1828         # dim == 3 or dim > 4

RuntimeError: invalid argument 3: only batches of spatial targets supported (3D tensors) but got targets of dimension: 2 at /pytorch/aten/src/THNN/generic/SpatialClassNLLCriterion.c:61

Inside the train() for loop, the inputs have shape (32, 1, 28, 28), and the offending target has shape (32, 1).
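
For reference, here is a minimal, self-contained sketch (not the code above; the shapes are chosen to match the question, i.e. batch size 32 and 10 classes) of the shape contract nn.CrossEntropyLoss enforces, covering the two branches visible in the traceback:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Plain classification: logits of shape (N, C), targets of shape (N,)
# holding class indices -- no unsqueeze needed.
logits = torch.randn(32, 10)
target = torch.randint(0, 10, (32,))
print(loss_fn(logits, target))

# Spatial (segmentation-style) case: logits (N, C, H, W), targets (N, H, W).
logits4d = torch.randn(32, 10, 28, 28)
target3d = torch.randint(0, 10, (32, 28, 28))
print(loss_fn(logits4d, target3d))

Since the traceback dispatches to nll_loss2d (the 4D, spatial branch), the model output presumably still has four dimensions (the (32, 1, 28, 28) inputs were not flattened before the first Linear layer), and the (32, 1) target produced by target.unsqueeze(1) matches neither contract.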

