火把形状不匹配

2024-04-29 11:32:28 发布

您现在位置:Python中文网/ 问答频道 /正文

很遗憾,我遇到以下运行时错误:

enter image description here

该错误出现在最后一批中的第1个历元(因此所有其他批次都会运行), 我不知道是什么原因导致我的代码中出现错误。下面是我的函数的代码片段

def gradient_penalty(critic, real, fake, device):

    BATCH_SIZE, C, H, W = real.shape
    epsilon = torch.rand(size = (BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)

    # generate tensor filles only with ones
    x = torch.ones(size = (BATCH_SIZE, C, H, W), dtype = int)

    # interpolate images
    interpolated_images = real * epsilon + fake * (x - epsilon)

变量real代表图像,其形状为(128, 3, 64, 64)。 我需要承认,我没有找到具体的错误消息,我。E张量的形状哪里不重合

任何帮助都将不胜感激


Tags: 代码sizedevice错误batchones原因torch
1条回答
网友
1楼 · 发布于 2024-04-29 11:32:28

使用drop_last参数实例化^{}时,可以放弃未完成的批处理:

torch.utils.data.DataLoader(trainset, batch_size=128, discard_last=True)

然而,这似乎有点激进,因为数据集中的128元素将被浪费

相关问题 更多 >