绘制ROC曲线 - 索引过多错误

2 投票
1 回答
2073 浏览
提问于 2025-04-18 15:57

我正在直接从这里获取ROC代码:http://scikit-learn.org/stable/auto_examples/plot_roc.html

我在循环中把类别数量硬编码为46,如你所见,但即使我把它设置为2,我仍然会遇到错误。

# Compute ROC curve and ROC area for each class
tpr = dict()
roc_auc = dict()
for i in range(46):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

错误信息是:

Traceback (most recent call last):
  File "C:\Users\app\Documents\Python Scripts\gbc_classifier_test.py", line 150, in <module>
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
IndexError: too many indices

y_pred的情况如下: array.shape() 出现错误,元组不可调用

y_test只是一个一维数组,和y_pred类似,只不过它包含的是我问题的真实类别。

我不明白,什么东西的索引太多了?

1 个回答

5

你提到的 y_predy_test 都是一维的数组,所以用 y_pred[:, i]y_test[:, i] 这样的方式来索引就不对了。因为一维数组只能用一个索引来访问。

所以,你可以直接调用 roc_curve(y_test, y_pred) 来处理。

撰写回答