多标签绘制混淆矩阵sklearn

2024-05-17 17:39:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在绘制一个多标签数据的混淆矩阵,其中标签看起来像:

label1: 1, 0, 0, 0

label2: 0, 1, 0, 0

label3: 0, 0, 1, 0

label4: 0, 0, 0, 1

我可以使用下面的代码成功地分类。我只需要一些帮助来绘制混淆矩阵。

    for i in range(4):
        y_train= y[:,i]
        print('Train subject %d, class %s' % (subject, cols[i]))
        lr.fit(X_train[::sample,:],y_train[::sample])
        pred[:,i] = lr.predict_proba(X_test)[:,1]

我使用下面的代码来打印混淆矩阵,但它总是返回一个2X2矩阵

prediction = lr.predict(X_train)

print(confusion_matrix(y_train, prediction))

Tags: 数据sample代码绘制train矩阵标签predict