今天我试图从我的分类模型中绘制出混淆矩阵。
在搜索了一些页面后,我发现来自pyplot
的matshow
可以帮助我。
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title(title)
fig.colorbar(cax)
if labels:
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
如果我没有什么标签的话,效果很好
y_true = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'c', 'd', 'b', 'a', 'b', 'a']
y_pred = ['a', 'b', 'c', 'd', 'a', 'b', 'b', 'a', 'c', 'a', 'a', 'a', 'a', 'a']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)
但是如果我有很多标签,有些标签显示不正确
y_true = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
y_pred = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)
我的问题是如何在matshow plot中显示所有标签?我试过类似fontdict
的方法,但仍然不起作用
可以使用^{} 模块控制滴答声的频率。
在本例中,您需要设置}
1
的每一个倍数的勾号,这样我们就可以使用^{在调用
plt.show()
之前添加这两行:它将为您的
y_true
和y_pred
中的每个字母生成一个勾号和标签。我还更改了您的
matshow
调用,以使用您在函数调用中指定的colormap:为了完整起见,整个函数将如下所示:
可以使用^{} 方法指定标签。您的函数将如下所示(根据上述答案修改函数):
相关问题 更多 >
编程相关推荐