使用混淆矩阵作为scikit学习中交叉验证的评分指标

2024-04-20 00:30:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我在scikit learn中创建了一个管道

pipeline = Pipeline([
    ('bow', CountVectorizer()),  
    ('classifier', BernoulliNB()), 
])

使用交叉验证计算精度

scores = cross_val_score(pipeline,  # steps to convert raw messages      into models
                     train_set,  # training data
                     label_train,  # training labels
                     cv=5,  # split data randomly into 10 parts: 9 for training, 1 for scoring
                     scoring='accuracy',  # which scoring metric?
                     n_jobs=-1,  # -1 = use all cores = faster
                     )

如何报告混淆矩阵而不是“准确性”?


Tags: fordata管道pipelinetrainingtrainscikitlearn
3条回答

不过,您可以定义一个记分器,它使用混淆矩阵中的特定值。见here [link]。只是引用代码:

def tp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 0]
def tn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 1]
def fp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 0]
def fn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 1]
scoring = {'tp' : make_scorer(tp), 'tn' : make_scorer(tn),
           'fp' : make_scorer(fp), 'fn' : make_scorer(fn)}
cv_results = cross_validate(svm.fit(X, y), X, y, scoring=scoring)

这将对这四个记分器中的每一个执行交叉验证,并返回记分字典cv_results,例如,使用键test_tptest_tn等,其中包含来自每个交叉验证拆分的混淆矩阵值。

由此你可以重建一个平均混淆矩阵,但是Xemacross_val_predict似乎更适合这个。

注意,这实际上不适用于cross_val_score;您需要cross_validate(在scikit learn v0.19中引入)。

旁注:您可以使用这些记分器中的一个来通过网格搜索进行超参数优化。

*编辑:在[1,1]返回真负片,而不是[0,0]

您可以使用cross_val_predictSee the scikit-learn docs)而不是cross_val_score

而不是:

from sklearn.model_selection import cross_val_score
scores = cross_val_score(clf, x, y, cv=10)

你可以:

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
y_pred = cross_val_predict(clf, x, y, cv=10)
conf_mat = confusion_matrix(y, y_pred)

简短的回答是“你不能”。

您需要理解作为模型选择方法的cross_val_score和交叉验证之间的区别。cross_val_score顾名思义,只对分数起作用。混淆矩阵不是分数,它是对评价过程中发生的事情的一种总结。一个主要的区别是分数应该返回可排序对象,特别是在scikit learn-Afloat中。所以,根据得分,你可以通过简单比较b的得分是否更高来判断方法b和a是否更好。你不能用混淆矩阵来做这个,顾名思义,它是一个矩阵。

如果您想获得多个评估运行(如交叉验证)的混淆矩阵,则必须手动完成,这在scikit learn中并没有那么糟糕,实际上是几行代码。

kf = cross_validation.KFold(len(y), n_folds=5)
for train_index, test_index in kf:

   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]

   model.fit(X_train, y_train)
   print confusion_matrix(y_test, model.predict(X_test))

相关问题 更多 >