如何解决这个错误:期望flatten\u输入有3个维度,但是得到了形状为(1,28,28,3)的数组?

2024-04-23 19:04:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试图让我的代码识别数字使用tensorflow为我的学校项目。但我总是犯这个错误。有人能帮我吗?非常感谢你!你知道吗

试着像压平,改变大小等,但没有有效的。。。你知道吗

这是我的密码:

import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([

tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])

model.fit(x_train, y_train, epochs=1)
model.evaluate(x_test, y_test)
# Part 3 - Making new predictions
import numpy as np
from keras.preprocessing import image
import keras
test_image = image.load_img('Number 8.jpg', target_size=(28, 28))
test_image = image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis=0)
result = model.predict(test_image)
print(np.argmax(result[0]))

应为3的数组


Tags: testimageimportmodellayerstftensorflowas
1条回答
网友
1楼 · 发布于 2024-04-23 19:04:59

我想是从这条线来的

tf.keras.layers.Flatten(input_shape=(28,28)),

你可以用

tf.keras.layers.Flatten()

即使你的图像是(28,28),训练时也会有一个批量维度[batch_size, 28,28]。因为在model.fit中没有传递批大小,所以使用默认值。你知道吗

相关问题 更多 >