用于scikit学习的高级库gridsearch/cross-evaluation库
easy-gscv的Python项目详细描述
简单的网格搜索/交叉验证
在4行代码中从数据到分数。
这个库允许您通过 自动拆分数据集并同时使用 grid search和cross validation在培训过程中。用户可以自己传递定义参数,也可以让gscv对象 自动选择它们(基于分类器)。
这个库是scikit-learn项目的扩展。
示例:
from sklearn.neural_network import MLPClassifier
from sklearn import datasets
from easy_gscv.classifiers import GSCV
# Create test dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target
clf = MLPClassifier()
# Create model instance
gscv_model = GSCV(clf(), X, y)
# Get score
gscv_model.score()
安装
需要python 3.7+
pip install easy-gscv
创建
from easy_gscv.models import GSCV
clf = LogisticRegression()
gscv_model = GSCV(
clf(), X, y, cv=15, n_jobs=-1, params={
'C': [10, 100],
'penalty': ['l2']
}
)
不需要创建单独的列车/测试数据集,模型会这样做 初始化时自动。 如果未提供任何参数,则对默认集执行网格搜索。 但这些可以被推翻。
可以指定用于交叉验证的折叠数
使用cv
关键字。
要加快培训过程,可以使用n_jobs
参数
设置要使用的CPU内核数(或将其设置为-1
以使用所有可用的内核)。
模型接受sklearn分类器或字符串值。 通过调用“Classifiers”属性,可以获取有效分类器的列表。依次将字符串参数传递给gscv对象可保存 你不必自己导入sklearn分类器。
gscv_model = GSCV('RandomForestClassifier',, X, y)
gscv_model.classifiers
'KNeighborsClassifier',
'RandomForestClassifier',
'GradientBoostingClassifier',
'MLPClassifier',
'LogisticRegression',
得分
gscv_model.score()
对训练数据进行网格搜索。使用score
方法评估
通过对测试数据集进行评分,可以将模型推广到什么程度。
获取最佳估计量
gscv_model.get_best_estimator()
返回最佳得分sklearn分类器(基于训练数据)。 作为一个有效的scikit学习分类器,您可以使用它做任何 你可以用sklearn分类器。
当前支持以下分类器。最终目标是 支持未来所有scikit学习分类器。
- Kneighborsscrifier
- 随机林分类器
- GradientBoostingClassifier
- MLP分类程序
- 逻辑回归
获取健康详细信息
由于交叉验证返回平均值,因此 获得最佳评分分类器的更详细概述。
此方法返回如下所示的表,其中 然后可用于进一步细化 后续运行。
clf = KNeighborsClassifier()
gscv_model = GSCV(clf(), X, y)
gscv_model.get_fit_details()
0.965 (+/-0.026) for {'weights': 'uniform', 'n_neighbors': 3}
0.977 (+/-0.013) for {'weights': 'distance', 'n_neighbors': 3}
0.979 (+/-0.011) for {'weights': 'uniform', 'n_neighbors': 5}
0.979 (+/-0.011) for {'weights': 'distance', 'n_neighbors': 5}
0.976 (+/-0.018) for {'weights': 'uniform', 'n_neighbors': 8}
0.975 (+/-0.018) for {'weights': 'distance', 'n_neighbors': 8}
0.971 (+/-0.022) for {'weights': 'uniform', 'n_neighbors': 12}
0.973 (+/-0.024) for {'weights': 'distance', 'n_neighbors': 12}
0.973 (+/-0.025) for {'weights': 'uniform', 'n_neighbors': 15}