如何解读Scikit-learn混淆矩阵

1 投票
1 回答
1789 浏览
提问于 2025-04-18 04:16

我正在使用混淆矩阵来检查我的分类器的表现。

我在使用Scikit-Learn,但有点困惑。我该如何理解下面的结果呢:

from sklearn.metrics import confusion_matrix
>>> y_true = [2, 0, 2, 2, 0, 1]
>>> y_pred = [0, 0, 2, 2, 0, 2]
>>> confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
       [0, 0, 1],
       [1, 0, 2]])

我该如何判断这些预测值是否准确呢?

1 个回答

1

判断一个分类器好不好,最简单的方法就是用一些标准的错误测量方法来计算错误率,比如说均方误差(Mean squared error)。我想你的例子是从Scikit的文档上复制的,所以我假设你已经看过相关的定义。

这里我们有三个类别:012。在混淆矩阵的对角线上,你可以看到每个类别被正确预测的次数。比如从对角线上的2 0 2可以看出,类别0被正确分类了2次,类别1从来没有被正确预测过,而类别2被正确分类了2次。

在对角线的上下方有一些数字,这些数字告诉你某个类别(行号对应的类别)被错误分类为另一个类别(列号对应的类别)多少次。举个例子,如果你看第一列,在对角线下方你会看到0 1(在矩阵的左下角)。这里的1表示类别2(最后一行)有一次错误地被分类为0(第一列)。这说明在你的y_true中,有一个标签为2的样本被错误地分类为0,这个错误发生在第一个样本上。

如果你把混淆矩阵中的所有数字加起来,你会得到测试样本的总数(2 + 2 + 1 + 1 = 6,这和y_truey_pred的长度是一样的)。如果你把每一行的数字加起来,就能得到每个标签的样本数量:你可以验证一下,确实在y_pred中有两个0、一个1和三个2

如果你把矩阵中的元素除以这个总数,你就可以知道,比如说类别2的正确识别率大约是66%,而在三分之一的情况下,它会和类别0搞混(这也是混淆矩阵名字的由来)。

总结:

虽然单一的错误测量方法可以衡量整体表现,但通过混淆矩阵你可以判断一些情况,比如:

  • 你的分类器在所有类别上表现都很差

  • 或者它对某些类别处理得很好,而对另一些类别则不行(这提示你可以关注这些特定的数据部分,观察分类器在这些情况下的表现)

  • 它的表现不错,但经常把标签A和B搞混。例如,对于线性分类器,你可能需要检查一下这些类别是否可以线性分开。

等等。

撰写回答