Pyspark - Create a JSON column whose keys come from another CSV column
The input dataset looks like this:
| id | fields | f1 | f2 | f3 | f4 |
| -------- | -------- | -------- | -------- | -------- | -------- |
| 1 | f1, f2, f3 | 3 | 2 | 0 | 2 |
| 2 | f2, f4 | 2 | 4 | 2 | 5 |
| 3 | f1 | 7 | 6 | 4 | 6 |
The expected output is:
| id | fields | json_field |
| -------- | -------- | ------------------------------- |
| 1 | f1, f2, f3 | {"f1": 3, "f2": 2, "f3": 0} |
| 2 | f2, f4 | {"f2": 4, "f4": 5} |
| 3 | f1 | {"f1": 7} |
I have tried
input_df.select(
    col("id"),
    col("fields"),
    to_json(struct(split(col("fields"), ","))).alias("json_field")
)
but it does not work as expected.
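For reference, this is what the expression actually produces (my own reproduction; the auto-generated key name may vary by Spark version): struct() receives the split array as a single column, so to_json emits one array-valued key instead of one key per field name.
# Minimal reproduction sketch (assumes input_df as above): struct() wraps the
# whole split array as one field, so to_json yields a single array-valued key.
from pyspark.sql.functions import col, split, struct, to_json

input_df.select(
    to_json(struct(split(col("fields"), ","))).alias("json_field")
).show(truncate=False)
# e.g. {"split(fields, ,, -1)":["f1"," f2"," f3"]} for the first row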
2 Answers
0
You can use the Spark SQL functions collect_list and map_from_entries to dynamically build a map from the fields column:
data = [(1, "f1, f2, f3", 3, 2, 0, 2),
(2, "f2, f4", 2, 4, 2, 5),
(3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])
input_df.createOrReplaceTempView("input_view")
# INPUT: input_view
# +---+----------+---+---+---+---+
# |id |fields    |f1 |f2 |f3 |f4 |
# +---+----------+---+---+---+---+
# |1  |f1, f2, f3|3  |2  |0  |2  |
# |2  |f2, f4    |2  |4  |2  |5  |
# |3  |f1        |7  |6  |4  |6  |
# +---+----------+---+---+---+---+
output_df = spark.sql("""
    WITH exploded_view AS (
        SELECT id, explode(split(fields, ', ')) AS field, f1, f2, f3, f4
        FROM input_view
    )
    SELECT
        id,
        collect_list(field) AS fields,
        map_from_entries(collect_list(struct(field,
            CASE field WHEN 'f1' THEN f1 WHEN 'f2' THEN f2 WHEN 'f3' THEN f3 WHEN 'f4' THEN f4 END))) AS json_field
    FROM exploded_view
    GROUP BY id, f1, f2, f3, f4
    ORDER BY id
""")
output_df.show(truncate=False)
# OUTPUT:
# +---+------------+---------------------------+
# |id |fields      |json_field                 |
# +---+------------+---------------------------+
# |1  |[f1, f2, f3]|{f1 -> 3, f2 -> 2, f3 -> 0}|
# |2  |[f2, f4]    |{f2 -> 4, f4 -> 5}         |
# |3  |[f1]        |{f1 -> 7}                  |
# +---+------------+---------------------------+
With the DataFrame API, the same logic can be written as:
from pyspark.sql.functions import collect_list, explode, map_from_entries, split, struct, when

# Input data
data = [(1, "f1, f2, f3", 3, 2, 0, 2),
(2, "f2, f4", 2, 4, 2, 5),
(3, "f1", 7, 6, 4, 6)]
input_df = spark.createDataFrame(data, ["id", "fields", "f1", "f2", "f3", "f4"])
input_df.show(input_df.count(), False)
# Explode the 'fields' column into multiple rows
exploded_df = input_df.select("id", explode(split(input_df.fields, ", ")).alias("field"), "f1", "f2", "f3", "f4")
# Create the 'json_field' column using the 'field' column and the 'f1', 'f2', 'f3', 'f4' columns with map_from_entries
output_df = (exploded_df.groupBy("id", "f1", "f2", "f3", "f4")
             .agg(collect_list("field").alias("fields"),
                  map_from_entries(
                      collect_list(
                          struct("field",
                                 when(exploded_df.field == "f1", exploded_df.f1)
                                 .when(exploded_df.field == "f2", exploded_df.f2)
                                 .when(exploded_df.field == "f3", exploded_df.f3)
                                 .when(exploded_df.field == "f4", exploded_df.f4)))
                  ).alias("json_field"))
             .orderBy("id")
             .select("id", "fields", "json_field"))
output_df.show(output_df.count(), truncate=False)
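If you need json_field as an actual JSON string (as in the expected output) rather than a map column, a minimal follow-up sketch, assuming the output_df built above, is to wrap the map with to_json:
# Follow-up sketch (assumes output_df from above): serialize the map column
# to a JSON string so the column matches the expected {"f1": 3, ...} format.
from pyspark.sql.functions import to_json

output_df.withColumn("json_field", to_json("json_field")).show(truncate=False)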
2
Here is a simple solution. Follow these steps:
- Convert the columns to a map<string, string> using from_json(to_json(struct(*)), 'map<string, string>').
- Split fields to get each individual field name, then look it up in that map<string, string>.
from pyspark.sql.functions import expr

(df
 .withColumn("keys", expr("from_json(to_json(struct(*)), 'map<string, string>')"))
 .selectExpr(
     "id",
     "fields",
     "transform(split(fields, ','), field -> keys[field]) AS json_fields"
 )
 .show(truncate=False))
+---+--------+-----------+
|id |fields |json_fields|
+---+--------+-----------+
|1 |f1,f2,f3|[3, 2, 0] |
|2 |f2,f4 |[4, 5] |
|3 |f1 |[7] |
+---+--------+-----------+
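Note that this produces an array of values rather than the JSON object asked for. A variant of the same idea, sketched under the assumption of Spark 3.0+ for map_filter (and note the values come out as strings, since the intermediate map is map<string, string>), keeps only the listed fields in the map and serializes it:
# Sketch (assumes Spark 3.0+ for map_filter): keep only the fields named in
# 'fields', then serialize the filtered map; values are strings because the
# intermediate map type is map<string, string>.
from pyspark.sql.functions import expr

(df
 .withColumn("keys", expr("from_json(to_json(struct(*)), 'map<string, string>')"))
 .selectExpr(
     "id",
     "fields",
     "to_json(map_filter(keys, (k, v) -> array_contains(split(fields, ','), k))) AS json_field"
 )
 .show(truncate=False))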