不适用于数据的转换

2024-05-23 17:01:01 发布

您现在位置:Python中文网/ 问答频道 /正文

我是Pythorch的新人,想了解一些事情。在

我正在加载MNIST,如下所示:

transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(size, interpolation=2),
     # transforms.Grayscale(num_output_channels=1),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize((mean), (std))])


trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

然而,当我研究数据集时,即trainloader.dataset.train_data[0],我得到的张量在[0255]范围内,形状为(28,28)。在

我错过了什么?这是因为转换不是直接应用于dataloader,而是只在运行时应用的吗?否则,我如何浏览我的数据?在


Tags: 数据truedatasizebatchtransformtrain事情
1条回答
网友
1楼 · 发布于 2024-05-23 17:01:01

当调用Dataset__getitem__方法时,将应用这些转换。例如,看看MNIST数据集类的__getitem__方法:https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62

def __getitem__(self, index):
    """
    Args:
        index (int): Index
    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

当您索引训练集的MNIST实例时,__getitem__方法将被调用,例如:

^{pr2}$

有关__getitem__的详细信息:https://docs.python.org/3.6/reference/datamodel.html#object.getitem

ResizeRandomHorizontalFlip应该在ToTensor之前的原因是它们作用于PIL Images,为了一致性,Pytorch中的所有数据集首先将数据加载为PIL Image。事实上,你可以看到,在这里他们强迫这种行为通过:

img = Image.fromarray(img.numpy(), mode='L')

一旦您有了相应索引的PIL Image,转换将应用于

if self.transform is not None:
    img = self.transform(img)

ToTensorPIL Image转换为torch.Tensor,并且Normalize减去平均值并除以您提供的标准差。在

最终,一些变换将应用于带有

if self.target_transform is not None:
    target = self.target_transform(target)

最后返回处理后的图像和处理后的标签。所有这些都发生在一个trainset[key]调用中。在

import torch
from torchvision.transforms import *
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

transform_train = Compose([Resize(28, interpolation=2),
                           RandomHorizontalFlip(p=0.5),
                           ToTensor(),
                           Normalize([0.], [1.])])

trainset = MNIST(root='./data', train=True, download=True,
                 transform=transform_train)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
print(trainset[0][0].size(), trainset[0][0].min(), trainset[0][0].max())

显示

(torch.Size([1, 28, 28]), tensor(0.), tensor(1.))

相关问题 更多 >