Why do I get this error when training on my data?

Posted 2024-04-20 08:26:46

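Why do I get this error when running my training loop? This is my training code. The error is raised at loss = criterion(output, labels), and I don't understand why.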

import torch

# net and train_loader are defined earlier in the notebook
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_accuracies = []
train_loss = []
predictions = []
for epoch in range(5):
    iterations = 0
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):

        iterations += 1

        inputs = inputs.float()
        labels = labels.long()

        # Forward pass
        output = net(inputs)
        # Loss calculation
        loss = criterion(output, labels)

        running_loss = running_loss + loss.item()
        # torch.max over dim=1 returns (values, indices); the indices are the predicted classes
        _, prd = torch.max(output, dim=1)
        # prd holds one prediction per sample, so store the whole batch
        # (prd.item() only works when the batch size is 1)
        predictions.extend(prd.tolist())
        accuracy = (prd == labels).float().mean()
        train_accuracies.append(accuracy.item())
        train_loss.append(running_loss / iterations)

        # Clear the gradient buffer (we don't want to accumulate gradients)
        optimizer.zero_grad()
        # Backpropagation
        loss.backward()
        # Weight update: w <-- w - lr * gradient
        optimizer.step()

        # Report the average loss over the batches seen so far in this epoch
        print("Epoch [{}][{}/{}], Loss: {:.3f}".format(epoch, i, len(train_loader), running_loss / iterations))

The error I get is:

RuntimeError                              Traceback (most recent call last)
<ipython-input-76-4f34dec75c72> in <module>
     15         output = net(inputs)
     16         # Loss Calculation
---> 17         loss = criterion(output, labels)
     18 
     19         running_loss = running_loss + loss.item()

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at C:\w\1\s\tmp_conda_3.7_055457\conda\conda-bld\pytorch_1565416617654\work\aten\src\THNN/generic/ClassNLLCriterion.c:94

Any idea about this?
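The assertion in the traceback states the condition that failed: every target value must satisfy 0 <= target < n_classes, where n_classes is the size of the class dimension of `output` (`output.shape[1]` for an `(N, C)` output). `CrossEntropyLoss` expects integer class indices in that range, so this assertion typically fires when labels are stored as 1..C instead of 0..C-1, contain values like -1, or when the network's last layer has fewer outputs than there are classes. Below is a minimal plain-Python sketch (no PyTorch needed) for checking and remapping labels; `find_out_of_range`, `build_label_map` and `remap` are hypothetical helper names, not PyTorch APIs, the example labels are made up, and the remapping assumes the raw labels are sortable.

```python
from typing import Dict, List, Sequence


def find_out_of_range(labels: Sequence[int], n_classes: int) -> List[int]:
    """Return the distinct label values violating 0 <= label < n_classes.

    This mirrors the failed assertion `cur_target >= 0 && cur_target < n_classes`.
    """
    return sorted({y for y in labels if not 0 <= y < n_classes})


def build_label_map(labels: Sequence) -> Dict:
    """Map each distinct raw label to a contiguous index 0..C-1, in sorted order.

    Assumes the raw labels are mutually comparable (all ints, or all strings).
    """
    return {raw: idx for idx, raw in enumerate(sorted(set(labels)))}


def remap(labels: Sequence, label_map: Dict) -> List[int]:
    """Convert raw labels to class indices using a map from build_label_map."""
    return [label_map[y] for y in labels]


if __name__ == "__main__":
    # Hypothetical data: 3 output units, but labels were stored as 1, 2, 3.
    raw = [1, 2, 3, 3, 1]
    print(find_out_of_range(raw, n_classes=3))  # [3]
    mapping = build_label_map(raw)
    fixed = remap(raw, mapping)
    print(fixed)  # [0, 1, 2, 2, 0]
    print(find_out_of_range(fixed, n_classes=3))  # []
```

Inside the training loop, a call such as `find_out_of_range(labels.tolist(), output.shape[1])` placed just before `criterion(output, labels)` would list the offending values for a batch. If the labels turn out to be correct and the last layer simply has too few output units, remapping will not help; the layer's output size would need to equal the number of classes instead.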