我试图编写一个简单的代码来分类MNIST数据集,但只针对数字5、6、7、8、9。我写了下面的代码。在跑步的时候,我得到了0的准确度和nan的损失。我试着用数字0,1,2,3,4做同样的处理,它在训练数据上的准确率几乎达到99.4%(只是把下面的训练掩码和测试掩码改为[0,1,2,3,4])。有人能帮我理解为什么代码对5,6,7,8,9范围内的数字分类没有任何作用?提前感谢您的帮助!在
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Dropout, Flatten, MaxPooling2D
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
num_digits_to_classify = 5
train_mask59 = np.isin(y_train,[5,6,7,8,9])
test_mask59 = np.isin(y_test,[5,6,7,8,9])
x_train59, y_train59 = x_train[train_mask59], y_train[train_mask59]
x_test59, y_test59 = x_test[test_mask59], y_test[test_mask59]
# Reshaping the array to 4-dims so that it can work with the Keras API
x_train59 = x_train59.reshape(x_train59.shape[0], 28, 28, 1)
x_test59 = x_test59.reshape(x_test59.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)
# Making sure that the values are float so that we can get decimal points after division
x_train59 = x_train59.astype('float32')
x_test59 = x_test59.astype('float32')
# Normalizing the RGB codes by dividing it to the max RGB value.
x_train59 /= 255
x_test59 /= 255
checkpoint_path = "D:/home/work/Fast_Learning/training/cp.ckpt59"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True,
verbose=1)
def create_model():
model = Sequential()
model.add(Conv2D(28, kernel_size=(3,3), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten()) # Flattening the 2D arrays for fully connected layers
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_digits_to_classify,activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
model = create_model()
model.fit(x=x_train59,y=y_train59, epochs=20,callbacks = [cp_callback])
^{pr2}$
如果使用
sparse_categorical_crossentropy
,则目标标签应在[0, num_digits_to_classify)
范围内。当选择数字0到4时,情况就是这样,但是当选择5到9时,它们将被5偏移。因此,在拟合之前,应调整目标标签:顺便说一下,如果您使用
^{pr2}$tensorflow.keras
模块,如下所示:它应该生成一个相应的错误消息(不确定
keras
是否也这样做):对于任意数量的元素
有多个选项仅适合MNIST数字0-9的一个子集:
Dense
层中的节点数设置为10(并且只适合较少的数量)[0, 1, 5, 6, 8]
)压缩到[0, N)
范围内。在对于案例1)
^{4}$对于案例2)
注意,
sparse_categorical_crossentropy
用于标签是整数的情况,categorical_crossentropy
用于这些标签是一个热编码的情况。例如:相关问题 更多 >
编程相关推荐