如何将图像加载到Pytorch数据加载器中?

2024-05-23 19:30:54 发布

您现在位置:Python中文网/ 问答频道 /正文

Pythorch的数据加载和处理教程是非常具体的一个例子,有人可以帮助我什么样的功能应该像一个更通用的简单图像加载?

教程:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

我的数据:

我将MINST数据集作为jpg保存在下面的文件夹结构中。(我知道我可以只使用dataset类,但这纯粹是为了了解如何将简单的图像加载到pytorch中,而不使用csv或复杂的特性)。

文件夹名是标签,图像是28x28png的灰度,不需要转换。

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...

Tags: 数据org图像功能文件夹httpdatapng
2条回答

如果您正在使用mnist,那么pytorch中已经通过torchvision设置了一个预设值。
你可以的

import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd

transform = transforms.Compose(
[transforms.ToTensor(),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                    download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                      shuffle=True, num_workers=2)

如果您想泛化到图像目录(与上面的导入相同),可以

class mnistmTrainingDataset(torch.utils.data.Dataset):

    def __init__(self,text_file,root_dir,transform=transformMnistm):
        """
        Args:
            text_file(string): path to text file
            root_dir(string): directory with all train images
        """
        self.name_frame = pd.read_csv(text_file,sep=" ",usecols=range(1))
        self.label_frame = pd.read_csv(text_file,sep=" ",usecols=range(1,2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        #labels = labels.reshape(-1, 2)
        sample = {'image': image, 'labels': labels}

        return sample


mnistmTrainSet = mnistmTrainingDataset(text_file ='Downloads/mnist_m/mnist_m_train_labels.txt',
                                   root_dir = 'Downloads/mnist_m/mnist_m_train')

mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet,batch_size=16,shuffle=True, num_workers=2)

然后,您可以对其进行迭代,如下所示:

for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
    print("training sample for mnist-m")
    print(i_batch,sample_batched['image'],sample_batched['labels'])

有很多方法可以推广pytorch用于图像数据集加载,我知道的方法是子类化torch.utils.data.dataset

下面是我为Pythorch0.4.1所做的(应该仍然在1.3中工作)

def load_dataset():
    data_path = 'data/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    #train network

相关问题 更多 >