使用Python GridSearchCV比较插补器方法?

2024-04-28 11:29:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在对泰坦尼克号数据集进行预处理,以便通过一些回归来运行它。 在这种情况下,列车和测试集中的“年龄”列仅为每组中约80%的行填充

我希望使用SimpleImputer(来自sklearn.impute import SimpleImputer)来填充这些列中缺少的值,而不是仅仅删除没有“Age”的行

SimpleImputer的“method”参数有三个选项,用于处理数值数据。这些是平均值、中值和最频繁(模式)。(还有使用自定义值的选项,但因为我试图避免将值“装箱”,所以我不想使用此选项。)

最基本的方法是手动设置所需的数据集。我必须在每一列和测试数据集上运行每种插补器(插补器=SimpleImputer(strategy=“xxxxxx”),其中xxxxxx=‘平均’、‘中值’或‘最频繁’),然后得到六个不同的数据集,然后我必须通过随机森林回归器一次输入一个数据集

我知道GridSearchCV可以用来彻底比较回归器中参数值的各种组合,所以我想知道是否有人知道使用它或类似的方法来运行插补器的各种“方法”选项

我在想一些类似于以下psedoocode的东西-

param_grid = [
    {'method': ['mean','median', 'most frequent']},
]

forest_reg = RandomForestRegressor()
grid_search = GridSearchCV(forest_reg, param_grid, cv = 5, scoring = 'neg_mean_squared_error')

grid_search.fit(titanic_features[method], titanic_values[method])

有没有一个干净的方法来比较这样的选项

有没有更好的方法来比较这三个选项,而不是构建所有六个数据集,通过RF回归器运行它们,看看结果如何


Tags: 数据方法searchparam选项情况regmean
1条回答
网友
1楼 · 发布于 2024-04-28 11:29:41

SklearnPipeline正是为此而设计的。您必须在回归器之前创建一个具有插补器组件的管道。然后可以使用网格搜索参数grid和__传递组件特定的参数

示例代码(内联记录)

# Sample/synthetic data shape 1000 X 2
X = np.random.randn(1000,2)
y = 1.5*X[:,0]+3.2*X[:, 1]+2.4

# Randomly make 200 data points in each axis as nan's
X[np.random.randint(0,1000, 200), 0] = np.nan
X[np.random.randint(0,1000, 200), 1] = np.nan

# Simple pipeline which has an imputer followed by regressor
pipe = Pipeline(steps=[('impute', SimpleImputer(missing_values=np.nan)),
                       ('regressor', RandomForestRegressor())])

# 3 different imputers and 2 different regressors 
# a total of 6 different parameter combination will be searched
param_grid = {
        'impute__strategy': ["mean", "median", "most_frequent"],
        'regressor__max_depth': [2,3]
        }

# Run girdsearch
search = GridSearchCV(pipe, param_grid)
search.fit(X, y)

print("Best parameter (CV score=%0.3f):" % search.best_score_)
print(search.best_params_)

样本输出:

Best parameter (CV score=0.730):
{'impute__strategy': 'median', 'regressor__max_depth': 3}

因此,通过GridSearchCV,我们能够发现样本数据的最佳插补策略是median,如果max_dept的组合为3

您可以继续使用其他组件扩展管道

相关问题 更多 >