我正在尝试制作一个图像分类器,将图像分为层次类(粗糙标签1、粗糙标签2、精细标签)。数据存储在一个csv文件中,其中包含image_path, coarse_label1, coarse_label2, fine_label
列。我试图使用torch.data.utils中的Dataset和DataLoader类来创建数据生成器,以便将其提供给我的模型
我阅读了一些教程(Tutorial1和Tutorial2),并修改了dataset类,如下所示:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
from imgaug import augmenters as iaa
class WDataset(Dataset):
def __init__(self, img_path, coarse1, coarse2, fine, train_mode, shape):
"""
Args:
img_path (list): List of image paths.
coarse1, coarse2, fine: List of coarse and fine labels.
"""
self.img_path = img_path
self.c1_label_list = coarse1
self.c2_label_list = coarse2
self.fine_label_list = fine
self.train_mode = train_mode
self.shape = shape
def __len__(self):
return len(self.fine_label_list)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
augmentor = self._get_augmentor()
image = self.pil_loader(self.img_path[idx])
image = np.array(image.resize(self.shape))
output = self.fine_label_list[idx]
output_c1 = self.c1_label_list[idx]
output_c2 = self.c2_label_list[idx]
if self.train_mode:
image = augmentor.augment_image(image)
return image, output_c1, output_c2, output
def pil_loader(self, path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def _get_augmentor(self):
arg_list = []
arg_list.append(iaa.Affine(shear=(-5, 5), name="shear"))
arg_list.append(iaa.Affine(rotate=(-25, 25), name="rotate"))
arg_list.append(iaa.Dropout((0, 0.3), name="gomashio"))
arg_list.append(iaa.contrast.LinearContrast((0.7, 1.3), name="contrast"))
arg_list.append(iaa.Multiply((0.7, 1.3), name="darken_lighten"))
seq = iaa.SomeOf((0, 3), arg_list, random_order=True)
return seq
使用这个类,我能够生成一个数据集,并使用生成的数据集的索引访问数据。但是,当我创建一个DataLoader并尝试从中获取一批数据时,任务不会终止。换句话说
train_ds = WDataset(X[train_index], y_c1[train_index], y_c2[train_index], y_fine[train_index], train_mode=True, shape=(height, width))
image, coarse1, coarse2, fine = train_ds[index]
一切正常。 但是下面的代码一直在处理,没有给出任何警告、错误或输出
# Parameters
params = {'batch_size': 4,
'shuffle': True,
'num_workers': 2}
# Generators
train_generator = torch.utils.data.DataLoader(train_ds, **params)
next(iter(train_generator))
我想知道我写WDataset的方式是否正确? 我还应该如何修改它以创建数据加载器? 是否有其他方法将此类(分层/多类分类)数据提供给PyTorch中的模型
编辑:
我找到了问题的根源。事实证明,此错误发生在使用python的macOS中>;=3.8当DataLoader中的num_worker设置为大于零时。下面是描述这个问题的github issue
从本期杂志的评论中可以看出:
“这是一个fork vs spawn的问题。因此多处理相关,而不是PyTorch。 在mac上的Python3.8中,现在多处理的默认后端是spawn,并且找不到默认的collate_fn。 这可以通过强制start_方法fork或在单独的模块中定义collate_fn并导入来解决。”
目前没有回答
相关问题 更多 >
编程相关推荐