绘制ROC曲线 - 索引过多错误
我正在直接从这里获取ROC代码:http://scikit-learn.org/stable/auto_examples/plot_roc.html
我在循环中把类别数量硬编码为46,如你所见,但即使我把它设置为2,我仍然会遇到错误。
# Compute ROC curve and ROC area for each class
tpr = dict()
roc_auc = dict()
for i in range(46):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
错误信息是:
Traceback (most recent call last):
File "C:\Users\app\Documents\Python Scripts\gbc_classifier_test.py", line 150, in <module>
fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
IndexError: too many indices
y_pred
的情况如下:
array.shape() 出现错误,元组不可调用
而y_test
只是一个一维数组,和y_pred
类似,只不过它包含的是我问题的真实类别。
我不明白,什么东西的索引太多了?
1 个回答
5
你提到的 y_pred
和 y_test
都是一维的数组,所以用 y_pred[:, i]
和 y_test[:, i]
这样的方式来索引就不对了。因为一维数组只能用一个索引来访问。
所以,你可以直接调用 roc_curve(y_test, y_pred)
来处理。