获取值错误:设置具有来自的序列的数组元素tf.contrib.keras公司.预处理.image.ImageDatagenerator.

2024-03-28 12:53:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试在Tensorflow中进行数据扩充。我写了这个代码。在

import numpy as np
import tensorflow as tf
import tensorflow.contrib.keras as keras
import time, random

def get_image_data_generator():
    return keras.preprocessing.image.ImageDataGenerator(
    rotation_range=get_random_rotation_angle(),\
    width_shift_range=get_random_wh_shift(),\
    height_shift_range=get_random_wh_shift(),\
    shear_range=get_random_shear(),\
    zoom_range=get_random_zoom(),\
    horizontal_flip=get_random_flip(),\
    vertical_flip=get_random_flip(),\
    preprocessing_function=get_random_function())

def augment_data(image_array,label_array):
    print image_array.shape
    images_array = image_array.copy()
    labels_array = label_array.copy()
    #Create a list of various datagenerators with different arguments
    datagenerators = []
    ndg = 10
    #Creating 10 different generators
    for ndata in xrange(ndg):
        datagenerators.append(get_image_data_generator())
    #Setting batch_size to be equal to no.of images
    bsize = image_array.shape[0]
    print bsize
    #Obtaining the augmented data
    for dgen in datagenerators:
        dgen.fit(image_array)
        (aug_img,aug_label) = dgen.flow(image_array,label_array,batch_size=bsize,shuffle=True)
        print aug_img.shape
        #Concatenating with the original data
        images_array = np.concatenate([images_array,aug_img],axis=0)
        labels_array = np.concatenate([labels_array,aug_label],axis=0)
    return (images_array,labels_array)

当我使用

augment_data(image_array,label_array)

我得到一个错误

^{pr2}$

Edit::即使我将一个图像作为参数传递,也会出现此错误。在

我做错什么了?我不明白。请帮忙。在


Tags: imageimportdatagetlabelsshiftasrange
1条回答
网友
1楼 · 发布于 2024-03-28 12:53:12

Edit :: I am getting this error even if I pass a single image as argument.`

您能否将单个元素作为数组传递并看到:

示例:

image_array, label_array = augment_data([image], [label])

相关问题 更多 >