Pytorch:forward()接受1个位置参数,但给出了两个

2024-05-23 22:57:48 发布

您现在位置:Python中文网/ 问答频道 /正文

这是我的第三个帖子,请宽容我。 我已经创建了一些PyTorch代码来从.csv文件中对猫和狗进行分类。那部分有效。说到问题的实质,当运行模型作为进一步开发的测试时,conda抛出了一个错误,如下所示:

Traceback (most recent call last):
  File "c:\Users\bala006\OneDrive - St John's Anglican College\Desktop\Personal\Dorime\CatsDogs\cats_dogs.py", line 70, in <module>
    run(loss_func, dataloader, model)
  File "c:\Users\bala006\OneDrive - St John's Anglican College\Desktop\Personal\Dorime\CatsDogs\cats_dogs.py", line 55, in run        
    Pred = model(imagepre)
  File "C:\Users\bala006\Miniconda3.8\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 1 positional argument but 2 were given

下面是代码

class NN (nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.flatten = nn.Flatten()
        self.stack = nn.Sequential(
            nn.Linear(input_size, 300),
            nn.ReLU(),
            nn.Linear(300, 300),
            nn.ReLU(),
            nn.Linear(300, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 150),
            nn.ReLU(),
            nn.Linear(150,100),
            nn.ReLU(),
            nn.Linear(100,75),
            nn.ReLU(),
            nn.Linear(75, output_size),
            nn.ReLU()
        )
    def forward (x):
        x = self.flatten(x)
        logits = self.stack(x)
        return logits

model = NN(256,2)
y_label, imagepre = dataset.__getitem__(1)
print (y_label)
print (imagepre)

loss_func = nn.CrossEntropyLoss()
optimiser = optim.SGD(model.parameters(), lr = 0.001)
epochs = 2

def run (loss_fn, dataloader, model):
    size = dataset.__len__()
    for y_label in enumerate(dataloader):
        Pred = model(imagepre)
        print (Pred)

        # Loss
        loss = loss_fn(pred, y_label)

        # Backprop
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        
        print (loss)

for t in range(epochs):
    print ("Epoch #######################")
    run(loss_func, dataloader, model)

print ("Done")

经过反思,似乎出现了一些错误,但这一次的python错误消息并不太宽容。而且,我刚刚开始学习Pytork,开始了解发生了什么。一个不必回答的附带问题是,这是一个简洁高效的代码。我知道模型结构有点长,但是如果没有函数,代码似乎运行得很好。请帮助我,一个进入这个令人兴奋的领域的新手。 另外,很抱歉给您带来任何困惑,谢谢您的帮助


Tags: run代码inselfsizemodel错误nn