Keras中的小错误:ValueError的错误修正?

2024-03-28 07:33:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我训练了一个Keras模型,但是,我很难预测它。我的输入数组的形状是(400,2),输出数组的形状是(400,1)。现在,当我将参数array([1,2])传递给model.predict()函数时,得到以下错误:

 ValueError: Error when checking input: expected dense_1_input to have shape (2,) but got array with shape (1,).

这是毫无意义的,因为shape(array([1,2])) = (2,),因此model.predict函数应该接受它作为有效输入。

相反,当我传递一个形状为(1,2)的数组时,它的效果很好。那么Keras实现中有没有bug呢?你知道吗

我的模型如下:

from keras.models import Sequential
from keras.layers import Dense
from keras import optimizers
import numpy as np

data = np.random.rand(400,2)
Y = np.random.rand(400,1)

def base():
    model = Sequential()
    model.add(Dense(4,activation = 'tanh', input_dim = 2))
    model.add(Dense(1, activation = 'sigmoid'))
    model.compile(optimizer = optimizers.RMSprop(lr = 0.1), loss = 'binary_crossentropy', metrics= ['accuracy'])      
    return model 

model = base()
model.fit(data,Y, epochs = 10, batch_size =1)
model.predict(np.array([1,2]).reshape(2,1))   #Error
model.predict(np.array([1,2]).reshape(2,)) #Error
model.predict(np.array([1,2]).reshape(1,2)) #Works

Tags: fromimportinputmodelnperror数组array
2条回答

第一个维度是批次维度。因此,即使使用predict方法,也必须传递形状为(num_samples,) + (the input shape of network)的数组。这就是为什么当你传递一个形状为(1,2)的数组时它会工作,因为(1,)表示样本数,(2,)是网络的输入形状。你知道吗

如果模型需要(2,)数组,则应传递(2,)形状数组

x = x.reshape((2,))
model.predict(x)

相关问题 更多 >