当有多个输出时,如何修正ValueError(x和y应该具有相同的长度)?

2024-04-26 03:52:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在建立一个模型,它有一个图像输入(130130,1)和3个输出,每个输出包含一个(10,1)向量,其中softmax单独应用。你知道吗

(Inspired by J. Goodfellow, Yaroslav Bulatov, Julian Ibarz, Sacha Arnoud, and Vinay D. Shet. Multi-digit number recognition from street view imagery using deep convolutional neural networks. CoRR, abs/1312.6082, 2013. URL http://arxiv.org/abs/1312.6082 , sadly they didn't publish their network).

input = keras.layers.Input(shape=(130,130, 1)
l0 = keras.layers.Conv2D(32, (5, 5), padding="same")(input)
[conv-blocks etc]
l12 = keras.layers.Flatten()(l11)
l13 = keras.layers.Dense(4096, activation="relu")(l12)
l14 = keras.layers.Dense(4096, activation="relu")(l13)
output1 = keras.layers.Dense(10, activation="softmax")(l14)
output2 = keras.layers.Dense(10, activation="softmax")(l14)
output3 = keras.layers.Dense(10, activation="softmax")(l14)

model = keras.models.Model(inputs=input, outputs=[output1, output2, output3])
model.compile(loss=['categorical_crossentropy', 'categorical_crossentropy', 
              'categorical_crossentropy'],
              loss_weights=[1., 1., 1.],
              optimizer=optimizer,
              metrics=['accuracy'])

train_generator = train_datagen.flow(x_train,
              [[y_train[:, 0, :], y_train[:, 1, :], y_train[:, 2, :]], 
              batch_size=batch_size)

但是我得到:值错误:x(图像张量)和y(标签)应该有相同的长度。发现:x.shape=(1000,130,130,1),y.shape=(3,1000,10)

但如果我把它改成:

 [same as before]
 train_generator = train_datagen.flow(x_train,
              y_train, 
              batch_size=batch_size)

然后我得到:ValueError:检查模型目标时出错:传递给模型的Numpy数组列表不是模型所期望的大小。预计将看到3个阵列

  • 尺寸(x×U列)=(1000、130、130、1)
    • 其中每个图像是(130,130,1),有1000个图像
  • 尺寸(y\列车)=(1000、3、10)

documentation中,它是这样说的

model = Model(inputs=[main_input, auxiliary_input], outputs= 
[main_output, auxiliary_output])

然而,我不知道你怎么能有相同的长度输出和输入?你知道吗


Tags: 模型图像inputsizemodellayersbatchtrain
1条回答
网友
1楼 · 发布于 2024-04-26 03:52:33

感谢@Djib2011。当我在文档中查找示例以便在字典中传递它时,我注意到所有示例都使用model.fit(),而不是model.fit_generator()。你知道吗

所以我做了研究,发现还有一个bug(从2016年开始开放!)用于具有单输入和多输出的ImageDataGenerator。 悲惨的故事。你知道吗

所以解决方法是使用model.fit()而不是model.fit_generator()。你知道吗

相关问题 更多 >