如何使用sklearn或类似库的标准函数量化一系列预测的一致性,包括预测置信度

-1 投票
1 回答
38 浏览
提问于 2025-04-12 17:21

假设我让一个分类模型对同一个物体进行多次分类,但每次的情况都不一样。理想情况下,它应该每次都预测出相同的类别。但实际上,它的预测结果可能会有所不同。

所以,对于这个物体的一系列分类预测,我想测量一下这些预测结果的一致性。需要说明的是,这里不是在和某个真实结果进行比较,而是想看看这些预测结果之间的一致性。

  • 比如,一个完全一致的预测序列像 class_a, class_a, class_a, class_a 应该得满分。
  • 而一个不太一致的序列像 class_a, class_b, class_a, class_c 应该得个较低的分数。
  • 再比如,一个完全不一致的序列像 class_a, class_b, class_c, class_d 应该得最低分。

我的目标是找出哪些物体需要继续训练这个分类模型。如果模型对某个物体的预测不够一致,那我们可能需要把这个物体加入到数据集中,进行进一步的训练。

最好是这个方法可以适用于任何数量的类别,并且还要考虑到预测的置信度。比如序列 class_a (0.9), class_b (0.9), class_a (0.9), class_c (0.9) 应该得的分数比 class_a (0.9), class_b (0.2), class_a (0.8), class_c (0.3) 低,因为当预测结果的置信度很高但又不一致时,这样的情况是很糟糕的。

我可以自己动手做一个,但我想知道是否有现成的 sklearn 或 scipy(或者类似的)函数可以用?谢谢!

这个问题 的评论提到了 斯皮尔曼相关系数 或者 肯德尔相关系数。我也会去研究一下这些。

1 个回答

1

不确定这是不是你想要的内容:

import numpy as np
from collections import Counter

def consistency_score(predictions, confidences):
    """
    Calculate a consistency score for a sequence of predictions.
    

    """
    # Calculate base consistency as the frequency of the most common class
    most_common_class, most_common_freq = Counter(predictions).most_common(1)[0]
    base_consistency = most_common_freq / len(predictions)
    
    # Adjust consistency based on confidences
    # Penalize deviations from the most common class, especially with high confidence
    penalty = sum(conf for pred, conf in zip(predictions, confidences) if pred != most_common_class) / len(predictions)
    adjusted_consistency = max(0, base_consistency - penalty)
    
    return adjusted_consistency
  • 举个例子:

      predictions = ["class_a", "class_b", "class_a", "class_c"]
      confidences = [0.9, 0.9, 0.9, 0.9]
      score_high_confidence = consistency_score(predictions, confidences)
    
      predictions_low_confidence = ["class_a", "class_b", "class_a", "class_c"]
      confidences_low_confidence = [0.9, 0.2, 0.8, 0.3]
      score_low_confidence = consistency_score(predictions_low_confidence, confidences_low_confidence)
    
      print(f"High confidence inconsistencies score: {score_high_confidence}")
      print(f"Lower confidence inconsistencies score: {score_low_confidence}")
    

撰写回答