修改Dataset类以将数据馈送到pytorch中的分支CNN?

2024-06-16 09:04:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试制作一个图像分类器,将图像分为层次类(粗糙标签1、粗糙标签2、精细标签)。数据存储在一个csv文件中,其中包含image_path, coarse_label1, coarse_label2, fine_label列。我试图使用torch.data.utils中的Dataset和DataLoader类来创建数据生成器,以便将其提供给我的模型

我阅读了一些教程(Tutorial1Tutorial2),并修改了dataset类,如下所示:

import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
from imgaug import augmenters as iaa

class WDataset(Dataset):
    def __init__(self, img_path, coarse1, coarse2, fine, train_mode, shape):
        """
        Args:
            img_path (list): List of image paths.
            coarse1, coarse2, fine: List of coarse and fine labels.
        """
        self.img_path = img_path
        self.c1_label_list = coarse1
        self.c2_label_list = coarse2
        self.fine_label_list = fine
        self.train_mode = train_mode
        self.shape = shape

    def __len__(self):
        return len(self.fine_label_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        augmentor = self._get_augmentor()
        
        image = self.pil_loader(self.img_path[idx])
        image = np.array(image.resize(self.shape))
        output = self.fine_label_list[idx]
        output_c1 = self.c1_label_list[idx]
        output_c2 = self.c2_label_list[idx]
        if self.train_mode:
            image = augmentor.augment_image(image)

        return image, output_c1, output_c2, output
    
    def pil_loader(self, path: str) -> Image.Image:
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def _get_augmentor(self):
        arg_list = []
        arg_list.append(iaa.Affine(shear=(-5, 5), name="shear"))
        arg_list.append(iaa.Affine(rotate=(-25, 25), name="rotate"))
        arg_list.append(iaa.Dropout((0, 0.3), name="gomashio"))
        arg_list.append(iaa.contrast.LinearContrast((0.7, 1.3), name="contrast"))
        arg_list.append(iaa.Multiply((0.7, 1.3), name="darken_lighten"))
        seq = iaa.SomeOf((0, 3), arg_list, random_order=True)
        return seq

使用这个类,我能够生成一个数据集,并使用生成的数据集的索引访问数据。但是,当我创建一个DataLoader并尝试从中获取一批数据时,任务不会终止。换句话说

train_ds = WDataset(X[train_index], y_c1[train_index], y_c2[train_index], y_fine[train_index], train_mode=True, shape=(height, width))
image, coarse1, coarse2, fine = train_ds[index]

一切正常。 但是下面的代码一直在处理,没有给出任何警告、错误或输出

# Parameters
params = {'batch_size': 4,
          'shuffle': True,
          'num_workers': 2}
# Generators
train_generator = torch.utils.data.DataLoader(train_ds, **params)
next(iter(train_generator))

我想知道我写WDataset的方式是否正确? 我还应该如何修改它以创建数据加载器? 是否有其他方法将此类(分层/多类分类)数据提供给PyTorch中的模型

编辑:

我找到了问题的根源。事实证明,此错误发生在使用python的macOS中>;=3.8当DataLoader中的num_worker设置为大于零时。下面是描述这个问题的github issue

从本期杂志的评论中可以看出:

“这是一个fork vs spawn的问题。因此多处理相关,而不是PyTorch。 在mac上的Python3.8中,现在多处理的默认后端是spawn,并且找不到默认的collate_fn。 这可以通过强制start_方法fork或在单独的模块中定义collate_fn并导入来解决。”


Tags: 数据pathimageimportselfimgoutputarg