在sklearn预处理中跟踪输出列

2024-04-16 17:29:26 发布

您现在位置:Python中文网/ 问答频道 /正文

如何跟踪sklearn.compose.ColumnTransformer生成的转换数组的列?所谓“保持跟踪”,我的意思是执行反变换所需的每一位信息都必须显式地显示出来。这至少包括以下内容:

  1. 输出数组中每列的源变量是什么?你知道吗
  2. 如果输出数组的一列来自一个分类变量的热编码,那么这个分类是什么?你知道吗
  3. 每个变量的准确插补值是多少?你知道吗
  4. 用于标准化每个数值变量的(平均值,标准差)是多少?(由于插补的缺失值,这些可能与直接计算不同。)

我使用基于this answer的相同方法。我的输入数据集也是一个通用的pandas.DataFrame,有多个数字列和分类列。是的,这个答案可以转换原始数据集。但是我失去了对输出数组中列的跟踪。我需要这些信息的同行审查,报告写作,介绍和进一步的模型建设步骤。我一直在寻找一个系统的方法,但运气不好。你知道吗


Tags: 数据方法composeanswer信息pandas编码分类
1条回答
网友
1楼 · 发布于 2024-04-16 17:29:26

前面提到的答案是基于Sklearn中的this。你知道吗

您可以使用以下代码片段获得前两个问题的答案。你知道吗

def get_feature_names(columnTransformer):

    output_features = []

    for name, pipe, features in columnTransformer.transformers_:
        if name!='remainder':
            for i in pipe:
                trans_features = []
                if hasattr(i,'categories_'):
                    trans_features.extend(i.get_feature_names(features))
                else:
                    trans_features = features
            output_features.extend(trans_features)

    return output_features
import pandas as pd
pd.DataFrame(preprocessor.fit_transform(X_train),
            columns=get_feature_names(preprocessor))

enter image description here

transformed_cols = get_feature_names(preprocessor)

def get_original_column(col_index):
    return transformed_cols[col_index].split('_')[0]

get_original_column(3)
# 'embarked'

get_original_column(0)
# 'age'
def get_category(col_index):
    new_col = transformed_cols[col_index].split('_')
    return 'no category' if len(new_col)<2 else new_col[-1]

print(get_category(3))
# 'Q'

print(get_category(0))
# 'no category'

跟踪是否对某个特性进行了一些插补或缩放,对于当前版本的Sklearn来说并不容易。你知道吗

相关问题 更多 >