PySpark - 创建一个 JSON 列,其键来自另一个以逗号分隔(CSV)的列

1 投票
2 回答
57 浏览
提问于 2025-04-14 17:19

输入的数据集长这样

 |    id     | fields     | f1       | f2         | f3       | f4         |
 | --------  | --------   | -------- | --------   | -------- | --------   |
 |     1     | f1, f2, f3 | 3        |  2         |     0    |     2      |
 |     2     | f2, f4     | 2        |  4         |     2    |     5      |
 |     3     | f1         | 7        |  6         |     4    |     6      |

我们期望的输出是

 |    id     | fields     | json_field                      |
 | --------  | --------   | ------------------------------- |    
 |     1     | f1, f2, f3 | {"f1": 3, "f2": 2, "f3": 0}     |     
 |     2     | f2, f4     | {"f2": 4, "f4": 5}              |     
 |     3     | f1         | {"f1": 7}                       |

我尝试了以下代码:

    # Asker's original attempt, kept as posted. It does not produce the
    # desired output because:
    #  - `selet` is a typo for `select`, so the call fails with AttributeError.
    #  - split(col("fields"), ",") keeps the space after each comma (" f2").
    #  - struct(split(...)) wraps the whole array of key names as one struct
    #    field, so to_json serializes that array; the values in f1..f4 are
    #    never looked up.
    input_df.selet(
        col("id"),
        col("fields"),
        to_json(struct(split(col("fields"), ","))).alias("json_field")
    )

但是它没有正常工作。

2 个回答

0

你可以使用 collect_list 和 map_from_entries 这两个 Spark SQL 函数,根据 fields 列动态创建一个映射(map):

data = [(1, "f1, f2, f3", 3, 2, 0, 2),
        (2, "f2, f4", 2, 4, 2, 5),
        (3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])

input_df.createOrReplaceTempView("input_view")
# INPUT: input_view
# +---+----------+---+---+---+---+
# |id |fields    |f1 |f2 |f3 |f4 |
# +---+----------+---+---+---+---+
# |1  |f1, f2, f3|3  |2  |0  |2  |
# |2  |f2, f4    |2  |4  |2  |5  |
# |3  |f1        |7  |6  |4  |6  |
# +---+----------+---+---+---+---+

# 1. exploded_view: one row per key listed in `fields`. The list is split on
#    ',' and every key is trimmed, so both "f1, f2" and "f1,f2" work.
# 2. For each key, CASE picks the value of the column with that name;
#    map_from_entries turns the collected (key, value) pairs into a map and
#    to_json serializes it, so json_field is a JSON string.
# Note: collect_list does not guarantee element order after the GROUP BY
# shuffle, so the key order inside the JSON object may differ from `fields`.
output_df = spark.sql("""
    WITH exploded_view AS (
      SELECT id, fields,
             explode(transform(split(fields, ','), k -> trim(k))) AS field,
             f1, f2, f3, f4
      FROM input_view
    )
    SELECT
      id,
      fields,
      to_json(map_from_entries(collect_list(struct(
        field,
        CASE field WHEN 'f1' THEN f1 WHEN 'f2' THEN f2
                   WHEN 'f3' THEN f3 WHEN 'f4' THEN f4 END
      )))) AS json_field
    FROM exploded_view
    GROUP BY id, fields, f1, f2, f3, f4
    ORDER BY id
""")

output_df.show(truncate=False)
# OUTPUT:
# +---+----------+----------------------+
# |id |fields    |json_field            |
# +---+----------+----------------------+
# |1  |f1, f2, f3|{"f1":3,"f2":2,"f3":0}|
# |2  |f2, f4    |{"f2":4,"f4":5}       |
# |3  |f1        |{"f1":7}              |
# +---+----------+----------------------+

在 DataFrame API 中,代码可以写成这样:

from pyspark.sql import functions as F

# Input data
data = [(1, "f1, f2, f3", 3, 2, 0, 2),
        (2, "f2, f4", 2, 4, 2, 5),
        (3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])

# truncate=False prints full cell contents; the default of 20 rows covers this
# sample without running an extra count() job just to size the output.
input_df.show(truncate=False)

# Columns that may be referenced by name from the `fields` list.
value_columns = ["f1", "f2", "f3", "f4"]

# Explode the comma-separated 'fields' string into one row per key name.
# Keys are trimmed so both "f1, f2" and "f1,f2" work.
exploded_df = (input_df
               .select("id", "fields",
                       F.explode(F.split("fields", ",")).alias("field"),
                       *value_columns)
               .withColumn("field", F.trim("field")))

# Value of the column whose name equals `field`, or null if none matches.
# Equivalent to when(field == "f1", f1).when(field == "f2", f2)..., but
# generated from value_columns instead of written out by hand.
field_value = F.coalesce(*[F.when(F.col("field") == c, F.col(c)) for c in value_columns])

# Rebuild one row per input row: collect the (key, value) pairs into a map and
# serialize it with to_json so json_field is a JSON string such as
# {"f1":3,"f2":2,"f3":0}. collect_list does not guarantee element order after
# the groupBy shuffle, so key order inside the JSON object may vary.
output_df = (exploded_df
             .groupBy("id", "fields", *value_columns)
             .agg(F.to_json(F.map_from_entries(F.collect_list(F.struct("field", field_value))))
                  .alias("json_field"))
             .orderBy("id")
             .select("id", "fields", "json_field"))

output_df.show(truncate=False)
2

这里有一个简单的解决方案。按照以下步骤操作:

  • columns 转换成 map<string, string> 类型,可以使用 from_json(to_json(struct(*)), 'map<string, string>') 这个方法。
  • fields 拆分开来,获取每一个字段,然后在 map<string, string> 中查找它。
df
  // Round-trip the whole row through JSON to get a generic
  // map<column name, value>; values become strings because of the
  // map<string, string> schema.
  .withColumn("keys", expr("from_json(to_json(struct(*)), 'map<string, string>')"))
  // Split the key list and trim each name, so "f1, f2" and "f1,f2" both
  // match the column names (otherwise " f2" looks up nothing and yields null).
  .withColumn("field_names", expr("transform(split(fields, ','), f -> trim(f))"))
  .selectExpr(
    "id",
    "fields",
    // Pair each key with its looked-up value and serialize as a JSON object.
    "to_json(map_from_arrays(field_names, transform(field_names, f -> keys[f]))) AS json_field"
  )
  .show(false)

// Output for the question's input data:
// +---+----------+----------------------------+
// |id |fields    |json_field                  |
// +---+----------+----------------------------+
// |1  |f1, f2, f3|{"f1":"3","f2":"2","f3":"0"}|
// |2  |f2, f4    |{"f2":"4","f4":"5"}         |
// |3  |f1        |{"f1":"7"}                  |
// +---+----------+----------------------------+

撰写回答