GridSearchCV不支持多类吗?

2024-06-08 01:07:17 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图使用GridSearchCV来处理多类案例,基于以下答案:

Accelerating the prediction

但是我得到了值错误,multiclass format is not supported.

如何将此方法用于多类情况?

以下代码来自上述链接中的答案。

import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, make_scorer

X, y = make_classification(n_samples=3000, n_features=5, weights=[0.1, 0.9, 0.3])

pipe = make_pipeline(StandardScaler(), SVC(kernel='rbf', class_weight='auto'))

param_space = dict(svc__C=np.logspace(-5,0,5), svc__gamma=np.logspace(-2, 2, 10))

accuracy_score, recall_score, roc_auc_score
my_scorer = make_scorer(roc_auc_score, greater_is_better=True)

gscv = GridSearchCV(pipe, param_space, scoring=my_scorer)
gscv.fit(X, y)

print gscv.best_params_

Tags: 答案fromimportmakepipelineisnpsklearn
2条回答

它支持多类 您可以设置scoring = f1.macro的段落,例如:

gsearch1 = GridSearchCV(estimator = est1, param_grid=params_test1, scoring='f1_macro', cv=5, n_jobs=-1)

或评分='roc_auc_ovr'

roc_auc_score上的文档中:

Note: this implementation is restricted to the binary classification task or multilabel classification task in label indicator format.

通过“标签指示符格式”,它们意味着每个标签值都表示为二进制列(而不是单个列中的唯一目标值)。你不想为你的预测者那样做,因为它可能导致非互斥的预测(即,对案例p1同时预测标签2和4,或者对案例p2不预测标签)。

Pick或custom实现一个为多类问题定义良好的评分函数,例如F1 score。就我个人而言,我发现informedness比F1得分更有说服力,比roc-auc得分更容易推广到多类问题。

相关问题 更多 >

    热门问题