我正在尝试在Tensorflow中进行数据扩充。我写了这个代码。在
import numpy as np
import tensorflow as tf
import tensorflow.contrib.keras as keras
import time, random
def get_image_data_generator():
return keras.preprocessing.image.ImageDataGenerator(
rotation_range=get_random_rotation_angle(),\
width_shift_range=get_random_wh_shift(),\
height_shift_range=get_random_wh_shift(),\
shear_range=get_random_shear(),\
zoom_range=get_random_zoom(),\
horizontal_flip=get_random_flip(),\
vertical_flip=get_random_flip(),\
preprocessing_function=get_random_function())
def augment_data(image_array,label_array):
print image_array.shape
images_array = image_array.copy()
labels_array = label_array.copy()
#Create a list of various datagenerators with different arguments
datagenerators = []
ndg = 10
#Creating 10 different generators
for ndata in xrange(ndg):
datagenerators.append(get_image_data_generator())
#Setting batch_size to be equal to no.of images
bsize = image_array.shape[0]
print bsize
#Obtaining the augmented data
for dgen in datagenerators:
dgen.fit(image_array)
(aug_img,aug_label) = dgen.flow(image_array,label_array,batch_size=bsize,shuffle=True)
print aug_img.shape
#Concatenating with the original data
images_array = np.concatenate([images_array,aug_img],axis=0)
labels_array = np.concatenate([labels_array,aug_label],axis=0)
return (images_array,labels_array)
当我使用
augment_data(image_array,label_array)
我得到一个错误
^{pr2}$Edit::即使我将一个图像作为参数传递,也会出现此错误。在
我做错什么了?我不明白。请帮忙。在
您能否将单个元素作为数组传递并看到:
示例:
相关问题 更多 >
编程相关推荐