Confusion matrix only predicts classes 0 and 1

Posted 2024-04-19 17:16:07


I built the LSTM network below. It trains without errors, but its accuracy is only about 60%. I think the cause is that it only ever predicts labels 0 and 1, never 2 or 3 — the confusion matrix has all-zero columns for classes 2 and 3. Do you know why?

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Input, Dense, Dropout, Embedding, LSTM, Flatten
from keras.models import Model
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, cohen_kappa_score, roc_auc_score,
                             confusion_matrix)

plt.style.use('ggplot')
%matplotlib inline

data = pd.read_csv("dataset/train_set.csv", sep="\t")


data['num_words'] = data.Text.apply(lambda x : len(x.split()))


num_class = len(np.unique(data.Label.values)) # 4
y = data['Label'].values


MAX_LEN = 300
tokenizer = Tokenizer()
tokenizer.fit_on_texts(data.Text.values)


post_seq = tokenizer.texts_to_sequences(data.Text.values)
post_seq_padded = pad_sequences(post_seq, maxlen=MAX_LEN)


X_train, X_test, y_train, y_test = train_test_split(post_seq_padded, y, test_size=0.25)


vocab_size = len(tokenizer.word_index) + 1  # +1 because index 0 is reserved for padding


inputs = Input(shape=(MAX_LEN, ))
embedding_layer = Embedding(vocab_size,
                            128,
                            input_length=MAX_LEN)(inputs)

x = LSTM(64)(embedding_layer)
x = Dense(32, activation='relu')(x)
predictions = Dense(num_class, activation='softmax')(x)
model = Model(inputs=[inputs], outputs=predictions)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['acc'])

model.summary()

filepath="weights.hdf5"
checkpointer = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
history = model.fit(X_train, to_categorical(y_train), batch_size=64, verbose=1,
                    validation_split=0.25, shuffle=True, epochs=10,
                    callbacks=[checkpointer])

df = pd.DataFrame({'epochs':history.epoch, 'accuracy': history.history['acc'], 'validation_accuracy': history.history['val_acc']})
g = sns.pointplot(x="epochs", y="accuracy", data=df)
g = sns.pointplot(x="epochs", y="validation_accuracy", data=df, color='green')

model.load_weights('weights.hdf5')
predicted = np.argmax(model.predict(X_test), axis=1)
print(accuracy_score(y_test, predicted))

# predict probabilities for the test set and reduce to class labels
yhat_probs = model.predict(X_test, verbose=0)
yhat_classes = np.argmax(yhat_probs, axis=1)

# accuracy: (tp + tn) / (p + n)
accuracy = accuracy_score(y_test, yhat_classes)
print('Accuracy: %f' % accuracy)
# precision tp / (tp + fp)
precision = precision_score(y_test, yhat_classes, average='micro')
print('Precision: %f' % precision)
# recall: tp / (tp + fn)
recall = recall_score(y_test, yhat_classes, average='micro')
print('Recall: %f' % recall)
# f1: 2 tp / (2 tp + fp + fn)
f1 = f1_score(y_test, yhat_classes, average='micro')
print('F1 score: %f' % f1)
matrix = confusion_matrix(y_test, yhat_classes) 
print(matrix)

Confusion matrix:

[[324 146   0   0]
 [109 221   0   0]
 [ 55  34   0   0]
 [ 50  16   0   0]]
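The row sums of the printed confusion matrix give the number of true instances per class, which makes the imbalance easy to quantify; a quick sketch in numpy (values copied from the matrix above):

```python
import numpy as np

# Confusion matrix exactly as printed above (rows = true class, cols = predicted)
cm = np.array([[324, 146, 0, 0],
               [109, 221, 0, 0],
               [ 55,  34, 0, 0],
               [ 50,  16, 0, 0]])

support = cm.sum(axis=1)            # true instances per class
share_01 = support[:2].sum() / support.sum()
print(support)                      # [470 330  89  66]
print(round(share_01, 2))           # classes 0 and 1 cover ~0.84 of the test set
```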

`average` is set to 'micro', and the output layer has four nodes for the four classes. The accuracy, F1 score, and recall below come from the training set only (class 2 is sometimes predicted, but class 3 never is):

Accuracy: 0.888539
Precision: 0.888539
Recall: 0.888539
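Note that with `average='micro'` on a single-label multiclass problem, precision, recall, and F1 all collapse to plain accuracy, which is why the three numbers above are identical; `average='macro'` would expose the missing classes. A small illustrative sketch with made-up labels (not the question's data):

```python
from sklearn.metrics import precision_score

# Toy imbalanced labels: classes 2 and 3 exist but are never predicted
y_true = [0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0, 0, 0, 1, 1, 0, 0, 0]

micro = precision_score(y_true, y_pred, average='micro')
macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
print(micro)  # 0.5  -> equals accuracy, hides the missing classes
print(macro)  # 0.25 -> per-class average, so the zeros for classes 2 and 3 drag it down
```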

Does anyone know why this happens?


1 Answer

User

#1 · Posted 2024-04-19 17:16:07

The model is probably stuck in a suboptimal solution. In your problem, classes 0 and 1 make up about 85% of all instances, so the dataset is quite imbalanced. The model predicts only classes 0 and 1 because it has not fully converged; this is a classic failure mode for this kind of model. Informally, you can think of the model as being lazy. I would suggest that you:

  • Train for longer.
  • Check whether your model can overfit your training data. To do this, train for much longer and measure the training error. You will see that, if there is no major problem with your model or data, the model will eventually predict classes 2 and 3 at least on the training set. From that point on you can rule out a problem with your data/model.
  • Use batch normalization; in practice I have seen it help escape this failure mode.
  • Always use a little dropout, which helps regularize the model.
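One concrete way to act on the imbalance is to pass per-class weights into `model.fit` via its `class_weight` argument, so that rare classes contribute more to the loss. A minimal sketch of computing "balanced" weights with plain numpy; the label counts here are hypothetical, mirroring the row sums of the question's confusion matrix:

```python
import numpy as np

# Hypothetical labels with the same per-class counts as the confusion matrix rows
y = np.array([0] * 470 + [1] * 330 + [2] * 89 + [3] * 66)

# "Balanced" heuristic: n_samples / (n_classes * count_per_class)
counts = np.bincount(y)
n_classes = len(counts)
class_weight = {c: len(y) / (n_classes * counts[c]) for c in range(n_classes)}
print(class_weight)

# The dict can then be passed to Keras:
# model.fit(X_train, to_categorical(y_train), class_weight=class_weight, ...)
```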
