Merging PySpark ML predictions with identifier data


I am using PySpark and its machine learning library to build a classification model. My input table contains an identifier column (called erp_number) that I want to exclude when building the model (i.e., it should not be used as a feature), but I want it back alongside the prediction results in the output.

from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.sql.types import StringType

def create_predictions(data, module):

    # Drop the identifier so it is not used as a feature
    # (but this also removes it from the predictions output)
    data = data.drop("erp_number")

    # Identify categorical columns
    categorical_columns = [field.name for field in data.schema.fields if isinstance(field.dataType, StringType)]

    # Numerical columns, excluding the categorical and target columns
    numerical_columns = [field.name for field in data.schema.fields if field.name not in categorical_columns and field.name != module]

    # Create a list of StringIndexers and OneHotEncoders
    stages = []
    for categorical_col in categorical_columns:
        string_indexer = StringIndexer(inputCol=categorical_col, outputCol=categorical_col + "_index", handleInvalid="keep")
        encoder = OneHotEncoder(inputCols=[string_indexer.getOutputCol()], outputCols=[categorical_col + "_vec"])
        stages += [string_indexer, encoder]

    # Add VectorAssembler to the pipeline stages
    feature_columns = [c + "_vec" for c in categorical_columns] + numerical_columns
    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
    stages += [assembler]

    # Add the GBTClassifier to the pipeline stages
    gbt = GBTClassifier(labelCol=module, featuresCol="features", predictionCol="prediction")
    stages += [gbt]

    # Create a Pipeline
    pipeline = Pipeline(stages=stages)

    # Fit the pipeline to the data
    model = pipeline.fit(data)

    # Apply the model to the data
    predictions = model.transform(data)

    return predictions

I tried dropping the column from the table, but there does not seem to be an equivalent of pandas' concat or dplyr's bind_cols to stitch it back on afterwards. I also tried excluding erp_number from the feature_columns list, but that raised an error in the pipeline.
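For reference, the closest Spark analogue to a positional bind_cols is joining on a generated row index. A minimal sketch of that workaround follows (variable names such as ids, preds, and rejoined are hypothetical, and predictions is assumed to come from the function above); note that ordering by monotonically_increasing_id() does not truly guarantee that the two DataFrames line up, which is why keeping the column in the DataFrame, as in the answer below, is the safer route:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Assign a row index to the saved identifiers and to the predictions,
# then join the two DataFrames on that index.
w = Window.orderBy(F.monotonically_increasing_id())
ids = data.select("erp_number").withColumn("row_idx", F.row_number().over(w))
preds = predictions.withColumn("row_idx", F.row_number().over(w))
rejoined = ids.join(preds, on="row_idx", how="inner").drop("row_idx")

# Caveat: row order in Spark is not stable across transformations, so this
# positional join can silently mismatch rows on a partitioned dataset.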

1 Answer


I finally found the way to do it: simply keep the field in the DataFrame, but leave it out of the model's features. This is very convenient. I had tried this before, but forgot to also remove the field from the numerical columns:

from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.sql.types import StringType

def create_predictions(data, module):

    # Identify categorical columns (excluding id)
    categorical_columns = [field.name for field in data.schema.fields if isinstance(field.dataType, StringType) and field.name != "erp_number"]

    # Numerical columns, excluding id, categorical and target columns
    numerical_columns = [field.name for field in data.schema.fields if field.name not in categorical_columns and field.name not in [module, "erp_number"]]

    # Create a list of StringIndexers and OneHotEncoders
    stages = []
    for categorical_col in categorical_columns:
        string_indexer = StringIndexer(inputCol=categorical_col, outputCol=categorical_col + "_index", handleInvalid="keep")
        encoder = OneHotEncoder(inputCols=[string_indexer.getOutputCol()], outputCols=[categorical_col + "_vec"])
        stages += [string_indexer, encoder]

    # Add VectorAssembler to the pipeline stages
    feature_columns = [c + "_vec" for c in categorical_columns] + numerical_columns
    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
    stages += [assembler]

    # Add the GBTClassifier to the pipeline stages
    gbt = GBTClassifier(labelCol=module, featuresCol="features", predictionCol="prediction")
    stages += [gbt]

    # Create a Pipeline
    pipeline = Pipeline(stages=stages)

    # Fit the pipeline to the data
    model = pipeline.fit(data)

    # Apply the model to the data
    predictions = model.transform(data)

    return predictions
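A minimal usage sketch (the DataFrame name df and the label column name "status" are assumptions; substitute your own):

# erp_number is never dropped, so the pipeline carries it through
# and it can be selected directly alongside the model output.
predictions = create_predictions(df, "status")
predictions.select("erp_number", "prediction", "probability").show(5)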
