如何解读Scikit-learn混淆矩阵
我正在使用混淆矩阵来检查我的分类器的表现。
我在使用Scikit-Learn,但有点困惑。我该如何理解下面的结果呢:
from sklearn.metrics import confusion_matrix
>>> y_true = [2, 0, 2, 2, 0, 1]
>>> y_pred = [0, 0, 2, 2, 0, 2]
>>> confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
[0, 0, 1],
[1, 0, 2]])
我该如何判断这些预测值是否准确呢?
1 个回答
判断一个分类器好不好,最简单的方法就是用一些标准的错误测量方法来计算错误率,比如说均方误差(Mean squared error)。我想你的例子是从Scikit的文档上复制的,所以我假设你已经看过相关的定义。
这里我们有三个类别:0
、1
和2
。在混淆矩阵的对角线上,你可以看到每个类别被正确预测的次数。比如从对角线上的2 0 2
可以看出,类别0
被正确分类了2次,类别1
从来没有被正确预测过,而类别2
被正确分类了2次。
在对角线的上下方有一些数字,这些数字告诉你某个类别(行号对应的类别)被错误分类为另一个类别(列号对应的类别)多少次。举个例子,如果你看第一列,在对角线下方你会看到0 1
(在矩阵的左下角)。这里的1
表示类别2
(最后一行)有一次错误地被分类为0
(第一列)。这说明在你的y_true
中,有一个标签为2
的样本被错误地分类为0
,这个错误发生在第一个样本上。
如果你把混淆矩阵中的所有数字加起来,你会得到测试样本的总数(2 + 2 + 1 + 1 = 6
,这和y_true
和y_pred
的长度是一样的)。如果你把每一行的数字加起来,就能得到每个标签的样本数量:你可以验证一下,确实在y_pred
中有两个0
、一个1
和三个2
。
如果你把矩阵中的元素除以这个总数,你就可以知道,比如说类别2
的正确识别率大约是66%,而在三分之一的情况下,它会和类别0
搞混(这也是混淆矩阵名字的由来)。
总结:
虽然单一的错误测量方法可以衡量整体表现,但通过混淆矩阵你可以判断一些情况,比如:
你的分类器在所有类别上表现都很差
或者它对某些类别处理得很好,而对另一些类别则不行(这提示你可以关注这些特定的数据部分,观察分类器在这些情况下的表现)
它的表现不错,但经常把标签A和B搞混。例如,对于线性分类器,你可能需要检查一下这些类别是否可以线性分开。
等等。