sklearn.svm.svc的predict_proba()函数是如何工作的?

46 投票
2 回答
40429 浏览
提问于 2025-04-17 17:18

我正在使用来自scikit-learnsklearn.svm.svc 来进行二分类。 我使用它的 predict_proba() 函数来获取概率估计。 有没有人能告诉我 predict_proba() 是怎么内部计算概率的?

2 个回答

-1

其实我找到了一种稍微不同的答案,他们用这段代码把决策值转换成概率。

'double fApB = decision_value*A+B;
if (fApB >= 0)
    return Math.exp(-fApB)/(1.0+Math.exp(-fApB));
else
     return 1.0/(1+Math.exp(fApB)) ;'

这里的A和B值可以在模型文件中找到(分别是probA和probB)。这提供了一种把概率转换成决策值的方法,从而计算出铰链损失。

记得用ln(0) = -200。

78

Scikit-learn这个库内部使用了LibSVM,而LibSVM又使用了一种叫做Platt缩放的方法,具体细节可以参考LibSVM作者的这篇说明。这个方法的目的是让支持向量机(SVM)不仅能给出分类结果,还能输出概率值。

使用Platt缩放时,首先要像平常一样训练SVM,然后再优化两个参数向量AB,使得

P(y|X) = 1 / (1 + exp(A * f(X) + B))

这里的f(X)表示样本到超平面的有符号距离(可以通过scikit-learn的decision_function方法得到)。你可能会在这个定义中看到逻辑 sigmoid 函数,这也是逻辑回归和神经网络用来将决策函数转化为概率估计的函数。

需要注意的是:B这个参数,也就是“截距”或“偏置”,可能会导致基于这个模型的概率估计的预测结果和通过SVM决策函数f得到的结果不一致。例如,假设f(X) = 10,那么对X的预测是正类;但是如果B = -9.9A = 1,那么P(y|X) = .475。这些数字我随便举的,但你可能会发现这种情况在实际中确实会发生。

实际上,Platt缩放是在SVM的输出上训练一个概率模型,并使用交叉熵损失函数。为了防止这个模型过拟合,它会使用内部的五折交叉验证,这意味着在训练时如果设置probability=True,那么训练SVM的成本会比普通的非概率SVM高很多。

撰写回答