如何在Pytorch数据集上应用OpenCV过滤器?

2024-04-20 08:38:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用以下步骤使用OpenCV预处理单个图像。现在,我想在Pytorch中训练模型之前,将这些预处理步骤应用于我的整个数据集。如何做到这一点

im = cv2.imread(image_path)
im_nonoise = cv2.medianBlur(im, 3)
imgray = cv2.cvtColor(im_nonoise,cv2.COLOR_BGR2GRAY)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
cl1 = clahe.apply(imgray)
ret,thresh = cv2.threshold(cl1,110,255,0)
image, contours, hierarchy = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
img = cv2.drawContours(image, contours, -1, (250,100,120))

我使用

data = datasets.ImageFolder(train_dir,transform=transform)
train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size,sampler=train_sampler)

Tags: imagedatasizebatch步骤transformtraincv2
1条回答
网友
1楼 · 发布于 2024-04-20 08:38:20

您可以构建自己的dataset类(从ImageFolder派生)并仅重载__getitem__方法:

class MySpecialDataset(datasets.ImageFolder):
  def __init__(self, root, loader=default_loader, is_valid_file=None):
    super(MySpecialDataset, self).__init__(root=root, loader=loader, is_valid_file=is_valid_file)

  def __getitem__(self, index):
    image_path, target = self.samples[index]
    # do your magic here
    im = cv2.imread(image_path)
    im_nonoise = cv2.medianBlur(im, 3)
    imgray = cv2.cvtColor(im_nonoise,cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    cl1 = clahe.apply(imgray)
    ret,thresh = cv2.threshold(cl1,110,255,0)
    image, contours, hierarchy = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    img = cv2.drawContours(image, contours, -1, (250,100,120))
    # you need to convert img from np.array to torch.tensor
    # this has to be done CAREFULLY!
    sample = torchvision.transforms.ToTensor()(img)
    return sample, target

拥有此数据集后,可以将其与基本pytorch的DataLoader一起使用:

data = MySpecialDataset(train_dir)
train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size,sampler=train_sampler)

相关问题 更多 >