我正在绘制一个多标签数据的混淆矩阵,其中标签看起来像:
label1: 1, 0, 0, 0
label2: 0, 1, 0, 0
label3: 0, 0, 1, 0
label4: 0, 0, 0, 1
我可以使用下面的代码成功地分类。我只需要一些帮助来绘制混淆矩阵。
for i in range(4):
y_train= y[:,i]
print('Train subject %d, class %s' % (subject, cols[i]))
lr.fit(X_train[::sample,:],y_train[::sample])
pred[:,i] = lr.predict_proba(X_test)[:,1]
我使用下面的代码来打印混淆矩阵,但它总是返回一个2X2矩阵
prediction = lr.predict(X_train)
print(confusion_matrix(y_train, prediction))
目前没有回答
相关问题 更多 >
编程相关推荐