我试图做一个简单的CNN模型,可以识别口袋妖怪。第一次尝试,我自己创建了一个非常小的数据集,由10个不同口袋妖怪的100张图片组成。 在Python中使用这段代码,看起来效果不错。你知道吗
import tensorflow as tf
import numpy as np
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3,3), input_shape=(200,200,3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2,2)))
model.add(tf.keras.layers.Conv2D(32, (3,3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2,2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=400, activation='relu'))
model.add(tf.keras.layers.Dense(units=10, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
train = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
test = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
training_set= train.flow_from_directory('datasets/starter/train', target_size=(200,200), class_mode='categorical')
val_set= test.flow_from_directory('datasets/starter/test', target_size=(200,200), class_mode='categorical')
history=model.fit_generator(training_set, steps_per_epoch=32, epochs=3, validation_data=val_set, validation_steps=32)
test_image = tf.keras.preprocessing.image.load_img('datasets/starter/val/attempt.png', target_size=(200, 200))
test_image = tf.keras.preprocessing.image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis=0)
result = model.predict(test_image)
print(training_set.class_indices)
print(result)
test_image2 = tf.keras.preprocessing.image.load_img('datasets/starter/val/attempt2.png', target_size=(200, 200))
test_image2 = tf.keras.preprocessing.image.img_to_array(test_image2)
test_image2 = np.expand_dims(test_image2, axis=0)
result2 = model.predict(test_image2)
print(training_set.class_indices)
print(result2)
最后一个历元的训练精度固定为1。 当我尝试预测示例图像时: 尝试.png是一个Charmander图片,它的标签是1,所以我得到这个向量:[[0。100. ... 0.]] attempt2.png是一个Torchic图片,它的标签是7,所以我得到:[[0。0. ... 100]]你知道吗
但我注意到模型.编译'应该是'categorical\u crossentropy',而不是'binary\u crossentropy'。使用分类的,我的程序将不再工作。 有人能帮我理解吗?你知道吗
你应该尝试使用Softmax作为最后的激活,用分类交叉熵作为损失函数
相关问题 更多 >
编程相关推荐