如何显示每个交叉验证折叠的混淆矩阵?

2024-04-25 02:06:17 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用Python处理一个多类分类问题。我已经用交叉验证训练了分类器(sklearn.model\选择.cross\u validate)并计算交叉验证每一次所需的指标(精确度、准确度、召回率、F1\u分数)。但是现在我需要保存每个折叠的混淆矩阵。我该怎么做?你知道吗

我已经尝试使用sklearn.model_selection.cross_val_predict来获得预测的类,并使用它和实际的类来计算sklearn.metrics.confusion_matrix,但是我只得到了一个混淆矩阵。你知道吗

scoring_multiclass = {
            'accuracy': 'accuracy',
            'precision_macro': 'precision_macro',
            'recall_macro': 'recall_macro',
            'f1_macro': 'f1_macro'}

cv = StratifiedKFold(n_splits=4)
cv_results = cross_validate(
    estimator=current_pipe,
    X=X,
    y=y,
    scoring=scoring_multiclass,
    cv=cv,
    n_jobs=-1)

mean_cv_results = mean_scores(cv_results)

results = {
    'precision': mean_cv_results['test_precision_macro'],
    'recall': mean_cv_results['test_recall_macro'],
    'f1': mean_cv_results['test_f1_macro'],
    'accuracy': mean_cv_results['test_accuracy'],}

如何显示每个折叠的混淆矩阵?你知道吗


Tags: testmodel矩阵sklearnmean交叉resultscv