如何统一分配车组

2024-04-28 03:45:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我有下面的目录结构

data/
    train/
        Cat 1/ ### 5000 pictures
            dog001.jpg

            ...
        cat 2/ ### 3000 pictures
            cat001.jpg

       Cat 3/ ### 50000 pictures
            Unicorn.jpg

            ...
        Cat 4/ ### 10000 pictures
            Angels.jpg

我使用下面的代码加载我的图片

datagen = ImageDataGenerator(rescale=1./255)

# automagically retrieve images and their classes for train and validation sets
train_generator = datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode="categorical")

因为我的数据分布不均匀,所以我的模型不适合,它变得偏向于Cat 3,所以我如何加载所有四个类别的一致的列车数据


Tags: and目录imgdatasizebatchtrain结构
2条回答

有两种方法:

  1. cat3中删除一些数据,这样就可以对数据进行统一洗牌
  2. 向其他类添加数据

1非常简单,要添加数据,您可以从其他不太频繁的类中复制数据,或者,更好的方法是从现有类中生成新数据

通过操纵图像,你可以设置一行/列为空白,你可以旋转图像或移动它,我用smth这样来实现这些效果一个28x28图像

import numpy as np
from scipy.ndimage.interpolation import rotate, shift

def rand_jitter(temp, prob=0.5):
    np.random.seed(1337)  # for reproducibility
    if np.random.random() > prob:
        temp[np.random.randint(0,28,1), :] = 0
    if np.random.random() > prob:
        temp[:, np.random.randint(0,28,1)] = 0
    if np.random.random() > prob:
        temp = shift(temp, shift=(np.random.randint(-3,4,2)))
    if np.random.random() > prob:
        temp = rotate(temp, angle = np.random.randint(-20,21,1), reshape=False)
    return temp

通过这种方法,你可以用更多的数据来训练你的网络,并对其进行推广,使其预测更加可靠

您不必删除任何数据点,并且应该保留尽可能多的数据点

为此,需要向现有的keras图像数据生成器添加一些代码,但应该很简单。这里的总体思路是提供一个自定义采样函数,根据目标类对训练数据点进行统一采样,可以分3步进行:

  1. 构建字典LUT={'class-1':[class-1 files],'class-2':[class-2 files],…,'class-k':[class-k files]}

  2. 以均匀随机的方式在LUT中选取一个键

  3. 以均匀随机的方式在LUT[key]中选取一个文件

相关问题 更多 >