MNIST数字分类仅适用于[5,6,7,8,9]范围内的数字

2024-06-16 17:56:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图编写一个简单的代码来分类MNIST数据集,但只针对数字5、6、7、8、9。我写了下面的代码。在跑步的时候,我得到了0的准确度和nan的损失。我试着用数字0,1,2,3,4做同样的处理,它在训练数据上的准确率几乎达到99.4%(只是把下面的训练掩码和测试掩码改为[0,1,2,3,4])。有人能帮我理解为什么代码对5,6,7,8,9范围内的数字分类没有任何作用?提前感谢您的帮助!在

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Dropout, Flatten, MaxPooling2D

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

num_digits_to_classify = 5

train_mask59 = np.isin(y_train,[5,6,7,8,9])
test_mask59 = np.isin(y_test,[5,6,7,8,9])

x_train59, y_train59 = x_train[train_mask59], y_train[train_mask59]
x_test59, y_test59 = x_test[test_mask59], y_test[test_mask59]

# Reshaping the array to 4-dims so that it can work with the Keras API
x_train59 = x_train59.reshape(x_train59.shape[0], 28, 28, 1)
x_test59 = x_test59.reshape(x_test59.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)
# Making sure that the values are float so that we can get decimal points after division
x_train59 = x_train59.astype('float32')
x_test59 = x_test59.astype('float32')
# Normalizing the RGB codes by dividing it to the max RGB value.
x_train59 /= 255
x_test59 /= 255


checkpoint_path = "D:/home/work/Fast_Learning/training/cp.ckpt59"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

def create_model():
    model = Sequential()
    model.add(Conv2D(28, kernel_size=(3,3), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten()) # Flattening the 2D arrays for fully connected layers
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(num_digits_to_classify,activation='softmax'))

    model.compile(optimizer='adam', 
                  loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy'])
    return model

model = create_model()

model.fit(x=x_train59,y=y_train59, epochs=20,callbacks = [cp_callback])

^{pr2}$

Tags: thetopath代码testimportaddmodel
1条回答
网友
1楼 · 发布于 2024-06-16 17:56:00

如果使用sparse_categorical_crossentropy,则目标标签应在[0, num_digits_to_classify)范围内。当选择数字0到4时,情况就是这样,但是当选择5到9时,它们将被5偏移。因此,在拟合之前,应调整目标标签:

y_train59 -= 5
y_test59 -= 5

顺便说一下,如果您使用tensorflow.keras模块,如下所示:

^{pr2}$

它应该生成一个相应的错误消息(不确定keras是否也这样做):

tensorflow.python.framework.errors_impl.InvalidArgumentError: Received a label value of 9 which is outside the valid range of [0, 5).

对于任意数量的元素

有多个选项仅适合MNIST数字0-9的一个子集:

  1. 将最后一个Dense层中的节点数设置为10(并且只适合较少的数量)
  2. 将实际使用的数字(例如[0, 1, 5, 6, 8])压缩到[0, N)范围内。在

对于案例1)

^{4}$

对于案例2)

# Transform y_train (and similarly y_test).
unique, index = np.unique(y_train, return_inverse=True)
y_train = np.arange(len(unique))[index]
# To get back the original labels, just index into the unique values.
unique[y_train]

注意,sparse_categorical_crossentropy用于标签是整数的情况,categorical_crossentropy用于这些标签是一个热编码的情况。例如:

sparse_categorical_crossentropy: y = [0, 2, 1, 1, 2, 0]
       categorical_crossentropy: y = [[1, 0, 0],
                                      [0, 0, 1],
                                      [0, 1, 0],
                                      [0, 1, 0],
                                      [0, 0, 1],
                                      [1, 0, 0]]

相关问题 更多 >