在这里发布之前,我试图在网上深入搜索解决方案,但我找不到。我的问题是在卷积神经网络训练中读取图像时出现的。基本上,我决定创建一个函数,从一系列图像中创建in值和out值。我希望读取集合的所有图像,但不是同时读取所有图像,以避免内存不足,因此我创建了下一个函数:
def readImages(strSet = 'Train', nIni = 1, nFin = 20):
if strSet not in ('Train','Test'):
return None
#
# Inicializamos los arrays de salida: las imágenes y las etiquetas.
arrImages = []
arrLabels = []
#
# Recorremos todos y cada uno de los directorios dentro del set elegido
for strDir in os.listdir(data_dir+'/' + strSet + '/'):
# Nombre de la clase que estamos tratando.
strClass = strDir[strDir.find('-')+1:]
# Número y nombre de los ficheros, por si es menor que el número n indicado.
arrNameFiles = os.listdir(data_dir+'/' + strSet + '/'+strDir)
nFiles = len(os.listdir(data_dir+'/' + strSet + '/'+strDir))
#
# Cogemos los ficheros desde el nIni al nFin. De esta forma nos aseguramos los cogemos todos en cada directorio.
#print('nImagesClase(',strSet,',',strClass,'):',nImagesClase(strSet, strClass))
if (nIni == -1):
# Si el valor es -1, cogemos todas las imágenes del directorio.
listChosenFiles = arrNameFiles
#print('Todos: ', len(listChosenFiles))
else:
if (nImagesClase(strSet, strClass)<nFin):
# Si ya hemos dado la vuelta a todos los ficheros del grupo, los cogemos al azar.
listChosenFiles = random.sample(arrNameFiles, min(nFiles, nFin-nIni))
#print('Fin del directorio ',nFin,'>',nImagesClase(strSet,strClass),': ', len(listChosenFiles))
else:
# Si no, seguimos.
listChosenFiles = arrNameFiles[nIni-1:min(nFin,nImagesClase(strSet, strClass))-1]
#print('Seguimos ',nIni,'-',nFin,': ', len(listChosenFiles))
#
for file in listChosenFiles:
# Lectura del fichero.
image = plt.imread(data_dir+'/'+strSet+'/'+strDir+'/'+file)
#print('Original Shape: ',image.shape)
#plt.imshow(image)
image = cv2.resize(image, (crop_width, crop_height), interpolation=cv2.INTER_NEAREST)
#image = image.reshape((image_height,image_width,num_channels))
#print('Al array de imágenes: ',image.shape)
arrImages.append(image)
# Añadimos etiquetas.
arrLabel = np.zeros(n_classes)
arrLabel[array_classes.index(strClass)] = 1
arrLabels.append(arrLabel)
#
# Recogemos los valores de entrada y salida en arrays.
y = np.array(arrLabels)
X = np.array(arrImages, dtype=np.uint8)
# Una vez terminado el recorrido por todas las imágenes, reordenamos los índices para que no vayan las imágenes en secuendias de la misma clase.
arrIndexes = np.arange(X.shape[0])
np.random.shuffle(arrIndexes)
X = X[arrIndexes]
y = y[arrIndexes]
#
return X, y
为了测试这个函数的行为,我只执行下面的代码行
X, y = readImages(strSet = 'Train', nIni = 1, nFin = 5)
这是正常的,直到nIni和nFin达到某些值(例如101-105)。在那一刻,我收到了以下错误:
ValueError Traceback (most recent call last)
<ipython-input-125-8a690256a1fc> in <module>
----> 1 X, y = readImages(strSet = 'Train', nIni = 101, nFin = 105)
<ipython-input-123-9e9ebc660c33> in readImages(strSet, nIni, nFin)
50 # Recogemos los valores de entrada y salida en arrays.
51 y = np.array(arrLabels)
---> 52 X = np.array(arrImages, dtype=np.uint8)
53 # Una vez terminado el recorrido por todas las imágenes, reordenamos los índices para que no vayan las imágenes en secuendias de la misma clase.
54 arrIndexes = np.arange(X.shape[0])
ValueError: could not broadcast input array from shape (28,28,3) into shape (28,28)
我在读取图像时放了一些打印痕迹,每个读取的图像的形状都是(28,28,3),所以我真的不明白错误痕迹中指出的(28,28)形状是从哪里来的
你知道有什么问题吗?你以前有没有遇到过这个问题
提前谢谢
您的某些图像具有单通道。用
cv2.imread
代替plt.imread
image = cv2.imread(data_dir+'/'+strSet+'/'+strDir+'/'+file)
相关问题 更多 >
编程相关推荐