如何使用PySpark用另一列的值替换列名?

1 投票
2 回答
73 浏览
提问于 2025-04-14 16:23

我有一个PySpark的数据表,内容如下:

ID col1 col2 colA colB
id_1 %colA t < %colA int1 int3
Id_2 %colB t < %colB int2 int4

我想把那些以%开头的字符串替换成对应的列值,结果应该是这样的:

ID col1 col2
id_1 int1 t < int1
Id_2 int4 t < int4

也许我可以通过循环每一行来实现这个功能。

但是有没有更高效的方法来做到这一点呢?

2 个回答

1

这里有一个可能的解决方案,我们可以利用正则表达式的高效性。而且因为我们使用了Spark的函数,Spark内部的优化器也会帮助提高效率,相比于自己写的循环解决方案,这样会更快。

dfSchema = StructType([
    StructField("id", StringType(), True),
    StructField("col1", StringType(), True),
    StructField("col2", StringType(), True),
    StructField("colA", IntegerType(), True),
    StructField("colB", IntegerType(), True),
])
df = spark.createDataFrame([
    ["id_1", "%colA", "t < %colA", 10, 20],
    ["id_2", "%colB", "t < %colB", 30, 40]
], schema=dfSchema)

df.show()

from pyspark.sql.functions import regexp_extract, col, when, regexp_replace
df.withColumn("col3", regexp_extract(col("col1"), r"(?<=\%).*", 0)) \
    .withColumn("col1", when(col("col3") == "colA", col("colA")).when(col("col3") == "colB", col("colB"))) \
    .withColumn("col4", regexp_extract(col("col2"), r"(?<=\%)(.*)", 1))\
    .withColumn("col5", when(col("col4") == "colA", col("colA")).when(col("col4") == "colB", col("colB")))\
    .withColumn("col2", regexp_replace(col("col2"), r"%(.*)[^\s]", col("col5")))\
.show()
+----+-----+---------+----+----+
|  id| col1|     col2|colA|colB|
+----+-----+---------+----+----+
|id_1|%colA|t < %colA|  10|  20|
|id_2|%colB|t < %colB|  30|  40|
+----+-----+---------+----+----+

+----+----+------+----+----+----+----+----+
|  id|col1|  col2|colA|colB|col3|col4|col5|
+----+----+------+----+----+----+----+----+
|id_1|  10|t < 10|  10|  20|colA|colA|  10|
|id_2|  40|t < 40|  30|  40|colB|colB|  40|
+----+----+------+----+----+----+----+----+

为了方便解释,我保留了一些临时的列。你可以把它们删掉。

1

你可以把需要的列转换成 map<string, string> 这种数据类型,这样你就可以用这些列在 map 中查找值了,就像下面这样。

df.withColumn(
  "mp", 
   expr("from_json(to_json(struct(*)),'map<string,string>')")
)
.withColumn(
  "col1", 
   expr("split(col1, '%')")
)
.withColumn(
  "col2", 
   expr("split(col2, '%')")
)
.selectExpr(
    "id", 
    "concat(col1[0],mp[col1[1]]) as col1", 
    "concat(col2[0],mp[col2[1]]) as col2", 
).show(false)
+----+----+------+
|id  |col1|col2  |
+----+----+------+
|id_1|int1|t<int1|
|id_2|int4|t<int4|
+----+----+------+

撰写回答