返回预测的概率向量

1 投票
3 回答
1007 浏览
提问于 2025-04-17 23:44

我正在使用scikit-learn进行分类。请问有没有办法得到一个概率向量,用来表示分类器对它的预测有多自信?我想要的是整个测试集的向量,而不是单个元素的。基本上,我需要这个向量来计算ROC曲线和AUC。

3 个回答

0

在scikit-learn这个库里,对于任何分类的方法,你都可以开启一个叫做probability的选项,然后使用predict_proba这个方法来获取每个元素属于各个类别的概率。举个例子,使用著名的鸢尾花数据集,

from sklearn import svm
from sklearn import datasets

# train set
iris = datasets.load_iris()
X = iris.data[0::2, :2]  
Y = iris.target[0::2]

clf = svm.SVC(probability=True)
clf.fit(X, Y) 

# test set
Z = iris.data[1::2, :2]

Y_predict = clf.predict(Z)
Y_actual = iris.target[1::2]
Y_probas = clf.predict_proba(Z) # probabilities of each classification
1

如果你只是想得到ROC曲线和AUC值,可以看看 sklearn.metrics.roc_auc_score,详细信息可以在这里找到。

根据文档的说明:

>>> import numpy as np
>>> from sklearn.metrics import roc_auc_score
>>> y_true = np.array([0, 0, 1, 1])
>>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> roc_auc_score(y_true, y_scores)
0.75

需要注意的是,roc_auc_score 只适用于二分类任务。如果你在处理多分类任务,可能需要为每个类别单独计算 roc_auc_score 值。

1

很多分类器都有一个叫做 decision_function 的方法,或者一个叫做 predict_proba 的方法(或者两者都有),可以用来获取软评分,而不是直接给出硬性的判断。举个例子:

>>> import numpy as np
>>> X = np.random.randn(10, 4)
>>> y = np.random.randint(0, 2, 10)
>>> from sklearn.svm import LinearSVC
>>> svm = LinearSVC().fit(X, y)
>>> svm.decision_function(X)
array([-0.92744332,  0.78697484, -0.71569751, -0.19938963, -0.15521737,
        0.45962204,  0.1326111 ,  0.44614422,  0.95731802,  0.8980536 ])

在这个例子中,值是线性支持向量机(SVM)超平面的有符号距离。predict_proba 有点不同,它返回的是一个概率矩阵,但你可以通过索引来获取一个正概率的向量:

>>> from sklearn.linear_model import LogisticRegression
>>> lr = LogisticRegression().fit(X, y)
>>> lr.predict_proba(X)
array([[ 0.73987796,  0.26012204],
       [ 0.26009545,  0.73990455],
       [ 0.63918314,  0.36081686],
       [ 0.62055698,  0.37944302],
       [ 0.54361598,  0.45638402],
       [ 0.38383357,  0.61616643],
       [ 0.50740302,  0.49259698],
       [ 0.39236783,  0.60763217],
       [ 0.32553896,  0.67446104],
       [ 0.20791651,  0.79208349]])
>>> lr.predict_proba(X)[:, 1]
array([ 0.26012204,  0.73990455,  0.36081686,  0.37944302,  0.45638402,
        0.61616643,  0.49259698,  0.60763217,  0.67446104,  0.79208349])

撰写回答