Sklearn StackingClassifier: add features as input to the final estimator

Posted 2024-04-26 18:43:36


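I'm using Pipeline and StackingClassifier to build a classification pipeline. In my setup, I want to pass some additional raw features to the final estimator along with the predictions of the previous level of models. Diagrammatically, it looks like this: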

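[Diagram: the predictions of Model 1 and Model 2, together with the raw features Feat x and Feat y, feed into the final estimator]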

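I'd still like to do this with Pipeline (I already have everything set up with it except for adding Feat x/y) and StackingClassifier, because together they handle end-to-end training of stacked models very cleanly. However, I don't see an option for adding raw features to the predictions of the previous "level" of models. Is there a good way to do this?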

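Note: the features fed to the final estimator are different from the features fed to Model 1 and Model 2, so I'm not looking for the `passthrough=True` flag.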
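
For context, here is a minimal sketch (assuming the Wisconsin breast cancer data bundled with scikit-learn, which the answer below also uses) of what `passthrough=True` alone produces. `StackingClassifier.transform` returns the matrix the final estimator sees: one positive-class probability per base model, followed by *every* original column, so the flag cannot restrict the final estimator to a chosen subset of raw features.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

# 569 samples x 30 numeric features, binary target
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

plain = StackingClassifier(
    estimators=[
        ("tree", DecisionTreeClassifier(random_state=42)),
        ("knn", KNeighborsClassifier()),
    ],
    final_estimator=LogisticRegression(max_iter=5000),
    passthrough=True,  # appends ALL input columns after the base predictions
).fit(X, y)

# Meta-features: 2 probability columns, then the 30 original columns
meta = plain.transform(X)
print(meta.shape)  # (569, 32)
```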


1 answer
Anonymous user
Answer #1 · Posted 2024-04-26 18:43:36

This isn't an easy feature to get, but I can think of two ways to cobble it together while still using StackingClassifier's automation. Each has some drawbacks.

Add the extra features as predictions

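Make a dummy predictor that "predicts" by returning its input, and use it as one of the base estimators; its "predictions", i.e. the extra features, are then passed on to the meta-estimator. Because the dummy only implements predict, StackingClassifier's default stack_method='auto' falls back to predict for it. Use ColumnTransformer to select the features for each base estimator and for the passthrough.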

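import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import StackingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

# Dummy "classifier" whose predictions are simply its input features.
# BaseEstimator supplies get_params/set_params, so StackingClassifier can clone it.
class IdentityPassthrough(ClassifierMixin, BaseEstimator):
    def fit(self, X, y):
        self.classes_ = np.unique(y)  # a fitted attribute, so the estimator counts as fitted
        return self
    def predict(self, X):
        return X

# Select the extra raw features 'x' and 'y' and hand them through unchanged
partial_passthrough = Pipeline([
    ('pass', ColumnTransformer([('pass', 'passthrough', ['x', 'y'])])),
    ('ident', IdentityPassthrough()),
])
# Features used by the real base models
base_features = ColumnTransformer([('pass', 'passthrough', ['a', 'b'])])

model = StackingClassifier(estimators=[
        ('pass', partial_passthrough),
        ('tree', Pipeline([('select', base_features), ('tree', DecisionTreeClassifier())])),
        ('knn', Pipeline([('select', base_features), ('knn', KNeighborsClassifier())])),
    ])

# X: a pandas DataFrame with columns 'a', 'b', 'x', 'y'; y: class labels
model.fit(X, y)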
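
A self-contained sketch of this first approach, assuming the breast cancer data from scikit-learn and the column choices shown in the model diagram ('mean perimeter' and 'mean area' as the extra features, 'mean radius' and 'mean texture' for the base models). The `select` helper is a hypothetical convenience function. In `fit`, IdentityPassthrough also sets `classes_` and `n_features_in_` so that, as far as I know, recent scikit-learn versions treat the enclosing Pipeline as fitted. Calling `transform` shows the final estimator's input: the two raw features first, then one probability column each from the tree and the KNN.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier


class IdentityPassthrough(ClassifierMixin, BaseEstimator):
    """Dummy 'classifier' whose predict() returns its input unchanged."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)  # conventional fitted attributes
        self.n_features_in_ = np.asarray(X).shape[1]
        return self

    def predict(self, X):
        return np.asarray(X)


def select(cols):
    """Hypothetical helper: a ColumnTransformer that keeps only `cols`."""
    return ColumnTransformer([("pass", "passthrough", cols)])


X, y = load_breast_cancer(return_X_y=True, as_frame=True)
extra = ["mean perimeter", "mean area"]  # reach the final estimator only
base = ["mean radius", "mean texture"]   # used by the base models

model = StackingClassifier(
    estimators=[
        ("pass", Pipeline([("pass", select(extra)), ("ident", IdentityPassthrough())])),
        ("tree", Pipeline([("select", select(base)), ("tree", DecisionTreeClassifier(random_state=42))])),
        ("knn", Pipeline([("select", select(base)), ("knn", KNeighborsClassifier())])),
    ],
    final_estimator=LogisticRegression(max_iter=5000),
).fit(X, y)

# Columns 0-1: raw extra features; 2: tree probability; 3: KNN probability
meta = model.transform(X)
print(meta.shape)  # (569, 4)
```

The dummy estimator gets `predict` as its stack method while the real models get `predict_proba`, which can be checked in `model.stack_method_`.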

Use passthrough and select the features in the meta-estimator

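Use a composite for the meta-estimator that selects, from the base estimators' predictions and the passed-through features, just the ones you want. This is a little worrying because you have to be sure you have the column order right (at least until sklearn handles feature names throughout). In the code below, features 0 and 1 are the predicted positive-class probabilities of the two base models (for a binary problem with predict_proba, StackingClassifier drops the negative-class column; if that column were kept, you would need columns 1 and 3 instead!), and 4 and 5 are the desired passthrough features (at indices 2 and 3 in the original frame).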
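
The index arithmetic above can be written out as a small function. This is a hypothetical helper, not a scikit-learn API; it assumes every base estimator uses the same stack method and handles only `predict` and `predict_proba`. Only the first four breast cancer column names are listed, since later columns do not affect these positions.

```python
def meta_column_indices(n_estimators, n_classes, stack_method, original_columns, wanted):
    """Hypothetical helper: positions of the base-model prediction columns and of
    the `wanted` passthrough features in StackingClassifier's meta-feature matrix
    (passthrough=True). Assumes every base estimator uses the same stack_method."""
    if stack_method == "predict":
        per_estimator = 1  # one predicted-label column per estimator
    elif stack_method == "predict_proba":
        # binary problems: StackingClassifier drops the first (negative-class) column
        per_estimator = 1 if n_classes == 2 else n_classes
    else:
        raise ValueError("this sketch only handles 'predict' and 'predict_proba'")

    n_pred = n_estimators * per_estimator
    prediction_cols = list(range(n_pred))
    # passthrough columns come after all prediction columns, in original order
    columns = list(original_columns)
    extra_cols = [n_pred + columns.index(name) for name in wanted]
    return prediction_cols, extra_cols


cols = ["mean radius", "mean texture", "mean perimeter", "mean area"]
print(meta_column_indices(2, 2, "predict_proba", cols, ["mean perimeter", "mean area"]))
# ([0, 1], [4, 5])
```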

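from sklearn.compose import ColumnTransformer
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

# Features used by the base models
base_features = ColumnTransformer([('pass', 'passthrough', ['mean radius', 'mean texture'])])

model = StackingClassifier(
    estimators=[
        ('tree', Pipeline([('select', base_features), ('tree', DecisionTreeClassifier(random_state=42))])),
        ('knn', Pipeline([('select', base_features), ('knn', KNeighborsClassifier())])),
    ],
    final_estimator=Pipeline([
        # Meta-features: 0, 1 = base-model probabilities; 2 onward = original columns,
        # so 4, 5 = 'mean perimeter', 'mean area'
        ('select', ColumnTransformer([('select', 'passthrough', [0, 1, 4, 5])])),
        ('model', LogisticRegression())
    ]),
    passthrough=True,  # append all of X after the base predictions
)

model.fit(X, y)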
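
Since the hard-coded indices are the fragile part, a quick sanity check helps. The sketch below (assuming the breast cancer data from scikit-learn) rebuilds the second approach and compares columns 4 and 5 of `model.transform(X)`, which is the meta-feature matrix before the final estimator's own ColumnTransformer, against the intended original features.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
base = ColumnTransformer([("pass", "passthrough", ["mean radius", "mean texture"])])

model = StackingClassifier(
    estimators=[
        ("tree", Pipeline([("select", base), ("tree", DecisionTreeClassifier(random_state=42))])),
        ("knn", Pipeline([("select", base), ("knn", KNeighborsClassifier())])),
    ],
    final_estimator=Pipeline([
        ("select", ColumnTransformer([("select", "passthrough", [0, 1, 4, 5])])),
        ("model", LogisticRegression(max_iter=5000)),
    ]),
    passthrough=True,
).fit(X, y)

# Full meta-feature matrix: 2 probability columns + all 30 original columns
meta = model.transform(X)
# Columns 4 and 5 should hold the 3rd and 4th original features
print(np.allclose(meta[:, 4], X["mean perimeter"]), np.allclose(meta[:, 5], X["mean area"]))
# True True
```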

Model diagrams:

[Diagram, first approach (exported with estimator_html_repr): StackingClassifier with base estimators pass = Pipeline(ColumnTransformer passthrough of ['mean perimeter', 'mean area'] → IdentityPassthrough), tree = Pipeline(ColumnTransformer passthrough of ['mean radius', 'mean texture'] → DecisionTreeClassifier(random_state=42)), knn = Pipeline(ColumnTransformer passthrough of ['mean radius', 'mean texture'] → KNeighborsClassifier()); final estimator LogisticRegression()]

[Diagram, second approach (exported with estimator_html_repr): StackingClassifier(passthrough=True) with base estimators tree and knn as above; final estimator Pipeline(ColumnTransformer passthrough of columns [0, 1, 4, 5] → LogisticRegression())]

See the whole thing in this notebook, using the Wisconsin breast cancer dataset.
