PyTorch:如何对现有数据集应用另一个转换?

2024-04-18 16:59:19 发布

您现在位置:Python中文网/ 问答频道 /正文

这是一个代码示例:

dataset = datasets.MNIST(root=root, train=istrain, transform=None)  #preserve raw img

print(type(dataset[0][0]))
# <class 'PIL.Image.Image'>

dataset = torch.utils.data.Subset(dataset, indices=SAMPLED_INDEX) # for resample

for ind in range(len(dataset)):
    img, label = dataset[ind] # <class 'PIL.Image.Image'> <class 'int'>/<class 'numpy.int64'>
    img.save(fp=os.path.join(saverawdir, f'{ind:02d}-{int(label):02d}.png'))

dataset.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
#transform for net forwarding

print(type(dataset[0][0]))
# expected <class 'torch.Tensor'>, however it's still <class 'PIL.Image.Image'>

由于数据集是随机重新采样的,所以我不想重新加载带有transform的新数据集,而只需将transform应用于已经存在的数据集

谢谢你的帮助:D


Tags: 数据imageimgforpiltypetransformroot