从数据生成器中随机裁剪图像

2024-06-08 02:43:40 发布

您现在位置:Python中文网/ 问答频道 /正文

我发现了这个代码做随机裁剪的图片为tensorflow后端。在

def random_crop(img, random_crop_size)
    # Note: image_data_format is 'channel_last'
    assert img.shape[2] == 3
    height, width = img.shape[0], img.shape[2]
    dy, dx = random_crop_size
    x = np.random.randint(0, width - dx + 1)
    y = np.random.randint(0, height - dy + 1)
    return img[y:(y+dy), x:(x+dx), :]
def crop_generator(batches, crop_length):
    while True:
        batch_x, batch_y = next(batches)
        batch_crops = np.zeros((batch_x.shape[0], crop_length, crop_length, 3))
        for i in range(batch_x.shape[0]):
            batch_crops[i] = random_crop(batch_x[i], (crop_length, crop_length))
        yield (batch_crops, batch_y)

在我的例子中,我希望将形状为(256,256,3)的图像裁剪成(224,224,3)。在imagenes_trigo中,每个项目都是一个包含图像路径的字符串。我用class PantasTrigodataset生成了我的数据。如何使用上面的图像代码裁剪data_traindata_valid的每个图像?如果有人帮我就太好了。谢谢。在

^{pr2}$

我试过了:

train_crops = crop_generator(data_train, crop_length)
valid_crops = crop_generator(data_val, crop_length)

但当我适合这样的模型时:

model.fit_generator(
                train_crops,
                steps_per_epoch= 2245,
                epochs=500,
                validation_data=valid_crops,
                validation_steps=748,
                callbacks=[csv_logger,checkpoint])

我得到这个错误:

File "/home/jokin/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 658, in _data_generator_task
generator_output = next(self._generator)

File "/home/jokin/PycharmProjects/TFG/PLANTS DISEASE/plants.py", line 106, in crop_generator
batch_x, batch_y = next(batches)
TypeError: 'PlantasTrigoDataset' object is not an iterator

Tags: 图像cropscropimgdatanpbatchbatches

热门问题