pyspark sql函数代替rdd distin

2024-04-24 08:34:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直在尝试为特定列替换数据集中的字符串。如果为1或0,则为“Y”,否则为0。在

通过使用lambda的dataframe到rdd的转换,我已经成功地确定了目标列,但这需要一段时间来处理。在

每个列都要切换到rdd,然后执行distinct,这需要一段时间!在

如果不同结果集中存在“Y”,则该列被标识为需要转换。在

我想知道是否有人能建议我如何使用pyspark sql函数来独占地获得相同的结果,而不必为每个列切换?在

示例数据的代码如下:

    import pyspark.sql.types as typ
    import pyspark.sql.functions as func

    col_names = [
        ('ALIVE', typ.StringType()),
        ('AGE', typ.IntegerType()),
        ('CAGE', typ.IntegerType()),
        ('CNT1', typ.IntegerType()),
        ('CNT2', typ.IntegerType()),
        ('CNT3', typ.IntegerType()),
        ('HE', typ.IntegerType()),
        ('WE', typ.IntegerType()),
        ('WG', typ.IntegerType()),
        ('DBP', typ.StringType()),
        ('DBG', typ.StringType()),
        ('HT1', typ.StringType()),
        ('HT2', typ.StringType()),
        ('PREV', typ.StringType())
        ]

    schema = typ.StructType([typ.StructField(c[0], c[1], False) for c in col_names])
    df = spark.createDataFrame([('Y',22,56,4,3,65,180,198,18,'N','Y','N','N','N'),
                                ('N',38,79,3,4,63,155,167,12,'N','N','N','Y','N'),
                                ('Y',39,81,6,6,60,128,152,24,'N','N','N','N','Y')]
                               ,schema=schema)

    cols = [(col.name, col.dataType) for col in df.schema]

    transform_cols = []

    for s in cols:
      if s[1] == typ.StringType():
        distinct_result = df.select(s[0]).distinct().rdd.map(lambda row: row[0]).collect()
        if 'Y' in distinct_result:
          transform_cols.append(s[0])

    print(transform_cols)

输出为:

^{pr2}$

Tags: 数据indfforsqlschematransformcol
1条回答
网友
1楼 · 发布于 2024-04-24 08:34:52

我设法使用udf来完成这项任务。首先,选择带有YN的列(这里我使用func.first来浏览第一行):

cols_sel = df.select([func.first(col).alias(col) for col in df.columns]).collect()[0].asDict()
cols = [col_name for (col_name, v) in cols_sel.items() if v in ['Y', 'N']]
# return ['HT2', 'ALIVE', 'DBP', 'HT1', 'PREV', 'DBG']

接下来,您可以创建udf函数,以便将YN映射到10。在

^{pr2}$

最后,可以对列求和。然后我将输出转换成dictionary并检查哪些列的值大于0(即containsY

out = df.select([func.sum(col).alias(col) for col in cols]).collect()
out = out[0]
print([col_name for (col_name, val) in out.asDict().items() if val > 0])

输出

['DBG', 'HT2', 'ALIVE', 'PREV']

相关问题 更多 >