在多行上绘制数据集和标签(jupyter笔记本)

2024-03-29 15:25:56 发布

您现在位置:Python中文网/ 问答频道 /正文

下面的代码简单地绘制了数据集(由狗和猫的图像组成)及其标签。我用的是jupyter笔记本:

train_path = 'dataset/train'
valid_path = 'dataset/valid'
test_path = 'dataset/test'

train_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(224,224), classes=['dog', 'cat'], batch_size=10)
valid_batches = ImageDataGenerator().flow_from_directory(valid_path, target_size=(224,224), classes=['dog', 'cat'], batch_size=4)
test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(224,224), classes=['dog', 'cat'], batch_size=10)

# plot function, used to plot images with labels within jupyter notebook
def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None):
    if type(ims[0]) is np.ndarray:
        ims= np.array(ims).astype(np.uint8)
        if (ims.shape[-1] != 3):
            ims = ims.transpose((0,2,3,1))
    f = plt.figure(figsize=figsize)
    cols = len(ims)//rows if len(ims) % 2 == 0 else len(ims)//rows + 1
    for i in range(len(ims)):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=16)
        plt.imshow(ims[i], interpolation=None if interp else 'none')

imgs, labels = next(train_batches)

# we plot these samples of images and their labels 1 batch at a time.
plots(imgs, titles=labels)

如果我按照上面的代码,每批只使用10个样本,这就足够适合笔记本页面宽度:

但是如果我想将批大小更改为大于这个值,比如说在一个批中有100个样本(或任何大小)(即在代码train_batches = ImageDataGenerator()更改batch_size=100)中绘制这个值,它只会尝试将其全部压缩在一行中,如下面的屏幕截图所示:

我该如何在代码中更改这一点,使它能够绘制所有100个数据样本,使其充分适合多行,而不是将它们全部(压缩到很小的大小)放在一行中。在

提前致谢。在


Tags: path代码testsizelabelslenifbatch