Pythorchlinging train_数据加载器的数据不足

2024-05-14 18:06:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我开始使用pytorch lightning,但我的自定义数据加载程序遇到了一个问题:

Im使用自己的数据集和公共torch.utils.data.DataLoader。基本上,数据集采用一条路径并加载与数据加载器加载的给定索引相对应的数据

def train_dataloader(self):
    train_set = TextKeypointsDataset(parameters...)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size, num_workers)
    return train_loader 

当我使用pytorch lightning模块train_dataloadertraining_step时,一切都运行良好。当我添加val_dataloadervalidation_step时,我面临这个错误:

Epoch 1:  45%|████▌     | 10/22 [00:02<00:03,  3.34it/s, loss=5.010, v_num=131199]
ValueError: Expected input batch_size (1500) to match target batch_size (5)

在本例中,我的数据集非常小(用于测试功能),共有84个样本,我的批量大小为8。用于培训和验证的数据集具有相同的长度(仅用于再次测试)

总的来说,它的84*2=168和168/8(batchsize)=21,这大致就是上面显示的总步骤(22)。这意味着在训练数据集上运行10次(10*8=80)后,加载程序期望新的完整样本为8,但由于只有84个样本,我得到了一个错误(至少这是我目前的理解)

我在自己的实现中遇到了类似的问题(不使用pytorch-lighntning),并使用此模式来解决它。基本上,我是在数据耗尽时重置迭代器:

try:
    data = next(data_iterator)
    source_tensor = data[0]
    target_tensor = data[1]

except StopIteration:  # reinitialize data loader if num_iteration > amount of data
    data_iterator = iter(data_loader)

现在我似乎面临着类似的问题?当我的训练数据加载程序的数据不足时,我不知道如何在pytorch lightning中重置/重新初始化数据加载程序。我想一定有另一种我不熟悉的复杂方式。多谢各位


Tags: 数据程序datasizebatchtrainutilstorch

热门问题