在scikit-learn中使用交叉验证绘制精准率-召回曲线

7 投票
2 回答
9368 浏览
提问于 2025-04-29 22:07

我正在使用交叉验证来评估一个分类器的表现,工具是 scikit-learn,我想绘制精确率-召回率曲线。我在 scikit-learn 的网站上找到了一个例子,可以绘制 PR 曲线,但那个例子没有使用交叉验证来进行评估。

我该如何在使用交叉验证时,在 scikit-learn 中绘制精确率-召回率曲线呢?

我做了以下尝试,但不确定这样做是否正确(伪代码):

for each k-fold:

   precision, recall, _ =  precision_recall_curve(y_test, probs)
   mean_precision += precision
   mean_recall += recall

mean_precision /= num_folds
mean_recall /= num_folds

plt.plot(recall, precision)

你觉得怎么样?

补充:

这样做不行,因为每次折叠后 precisionrecall 数组的大小不同。

有人能帮忙吗?

暂无标签

2 个回答

2

目前这是绘制sklearn分类器的精确率-召回率曲线的最佳方法,使用了交叉验证。最棒的是,它可以为所有类别绘制PR曲线,所以你会看到多条漂亮的曲线。

from scikitplot.classifiers import plot_precision_recall_curve
import matplotlib.pyplot as plt

clf = LogisticRegression()
plot_precision_recall_curve(clf, X, y)
plt.show()

这个函数会自动处理给定数据集的交叉验证,合并所有的预测结果,并计算每个类别的PR曲线以及平均PR曲线。它只需要一行代码,就能帮你搞定所有这些事情。

精确率-召回率曲线

免责声明:请注意,这个方法使用了我自己开发的scikit-plot库。

8

与其在每次交叉验证的折叠后记录精确度和召回率,不如在每次折叠后保存对测试样本的预测结果。接下来,收集所有测试(也就是未使用的数据)预测结果,然后计算精确度和召回率。

 ## let test_samples[k] = test samples for the kth fold (list of list)
 ## let train_samples[k] = test samples for the kth fold (list of list)

 for k in range(0, k):
      model = train(parameters, train_samples[k])
      predictions_fold[k] = predict(model, test_samples[k])

 # collect predictions
 predictions_combined = [p for preds in predictions_fold for p in preds]

 ## let predictions = rearranged predictions s.t. they are in the original order

 ## use predictions and labels to compute lists of TP, FP, FN
 ## use TP, FP, FN to compute precisions and recalls for one run of k-fold cross-validation

在一次完整的k折交叉验证中,预测器对每个样本只会做一次预测。如果你有n个样本,那么你应该得到n个测试预测结果。

(注意:这些预测结果与训练预测不同,因为预测器在做出预测时并没有见过这些样本。)

除非你使用留一法交叉验证,否则k折交叉验证通常需要随机划分数据。理想情况下,你应该进行重复的(和分层的) k折交叉验证。不过,将不同轮次的精确度-召回率曲线结合起来并不简单,因为你不能像处理ROC曲线那样简单地在精确度-召回率点之间进行线性插值(参见Davis和Goadrich 2006)。

我个人是使用Davis-Goadrich方法在PR空间中进行插值(然后进行数值积分)来计算AUC-PR,并通过重复的分层10折交叉验证来比较分类器的AUC-PR估计值。

为了得到一个漂亮的图,我展示了其中一次交叉验证的代表性PR曲线。

当然,还有许多其他方法可以评估分类器的性能,这取决于你的数据集的性质。

例如,如果你的数据集中(二元)标签的比例没有偏差(也就是说大约是50-50),你可以使用更简单的ROC分析结合交叉验证:

从每个折叠中收集预测结果,构建ROC曲线(如之前所述),收集所有的TPR-FPR点(也就是将所有TPR-FPR元组的集合取并集),然后绘制可能平滑处理后的点的组合集。可选地,使用简单的线性插值和复合梯形法计算AUC-ROC进行数值积分。

撰写回答