如果条件为m,则在Pyspark中合并两行

2024-05-12 21:00:27 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个PySpark数据表,如下所示

shouldMerge | number
true        | 1
true        | 1
true        | 2
false       | 3
false       | 1 

我想将所有shouldMerge为true的列合并起来,然后将数字相加。你知道吗

所以最终的输出看起来像

shouldMerge | number
true        | 4
false       | 3
false       | 1

如何选择shouldMerge==true的所有行,将这些数字相加,并在PySpark中生成新行?你知道吗

编辑:另一个稍微复杂一点的场景,更接近我要解决的问题,我们只聚合正数:

mergeId     | number
1           | 1
2           | 1
1           | 2
-1          | 3
-1          | 1 

shouldMerge | number
1        | 3
2        | 1
-1       | 3
-1       | 1

Tags: falsetrue编辑number场景数字pyspark数据表
3条回答

IIUC,你想做一个groupBy,但只在正的mergeId

一种方法是过滤数据帧中的正id、group、aggregate,并将其与负的id合并(类似于@shanmuga's answer)。你知道吗

另一种方法是使用when动态创建分组键。如果mergeId为阳性,则使用mergeId进行分组。否则,请使用monotonically_increasing_id来确保该行不会被聚合。你知道吗

举个例子:

import pyspark.sql.functions as f

df.withColumn("uid", f.monotonically_increasing_id())\
    .groupBy(
        f.when(
            f.col("mergeId") > 0, 
            f.col("mergeId")
        ).otherwise(f.col("uid")).alias("mergeKey"), 
        f.col("mergeId")
    )\
    .agg(f.sum("number").alias("number"))\
    .drop("mergeKey")\
    .show()
#+   -+   +
#|mergeId|number|
#+   -+   +
#|     -1|   1.0|
#|      1|   3.0|
#|      2|   1.0|
#|     -1|   3.0|
#+   -+   +

通过改变when条件(在本例中是f.col("mergeId") > 0)来满足您的特定需求,可以很容易地概括这一点。你知道吗


解释:

首先,我们创建一个临时列uid,它是每一行的唯一ID。接下来,我们调用groupBy,如果mergeId为正,则使用mergeId进行分组。否则我们使用uid作为mergeKey。我还将mergeId作为第二个groupby列传入,作为为输出保留该列的方法。你知道吗

要演示正在进行的操作,请查看中间结果:

df.withColumn("uid", f.monotonically_increasing_id())\
    .withColumn(
        "mergeKey",
        f.when(
            f.col("mergeId") > 0, 
            f.col("mergeId")
        ).otherwise(f.col("uid")).alias("mergeKey")
    )\
    .show()
#+   -+   +     -+     -+
#|mergeId|number|        uid|   mergeKey|
#+   -+   +     -+     -+
#|      1|     1|          0|          1|
#|      2|     1| 8589934592|          2|
#|      1|     2|17179869184|          1|
#|     -1|     3|25769803776|25769803776|
#|     -1|     1|25769803777|25769803777|
#+   -+   +     -+     -+

如您所见,mergeKey仍然是负的mergeId的唯一值

在这个中间步骤中,所需的结果只是一个简单的group by and sum,然后删除mergeKey列。你知道吗

您必须只筛选出应该合并为true并聚合的行。然后将其与所有剩余行合并。你知道吗

import pyspark.sql.functions as functions
df = sqlContext.createDataFrame([
    (True, 1),
    (True, 1),
    (True, 2),
    (False, 3),
    (False, 1),
], ("shouldMerge", "number"))

false_df = df.filter("shouldMerge = false")
true_df = df.filter("shouldMerge = true")
result = true_df.groupBy("shouldMerge")\
    .agg(functions.sum("number").alias("number"))\
    .unionAll(false_df)




df = sqlContext.createDataFrame([
    (1, 1),
    (2, 1),
    (1, 2),
    (-1, 3),
    (-1, 1),
], ("mergeId", "number"))

merge_condition = df["mergeId"] > -1
remaining = ~merge_condition
grouby_field = "mergeId"

false_df = df.filter(remaining)
true_df = df.filter(merge_condition)
result = true_df.groupBy(grouby_field)\
    .agg(functions.sum("number").alias("number"))\
    .unionAll(false_df)

result.show()

OP发布的第一个问题。

# Create the DataFrame
valuesCol = [(True,1),(True,1),(True,2),(False,3),(False,1)]
df = sqlContext.createDataFrame(valuesCol,['shouldMerge','number'])
df.show()
+     -+   +
|shouldMerge|number|
+     -+   +
|       true|     1|
|       true|     1|
|       true|     2|
|      false|     3|
|      false|     1|
+     -+   +

# Packages to be imported
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lag
# Register the dataframe as a view
df.registerTempTable('table_view')
df=sqlContext.sql(
    'select shouldMerge, number, sum(number) over (partition by shouldMerge) as sum_number from table_view'
)
df = df.withColumn('number',when(col('shouldMerge')==True,col('sum_number')).otherwise(col('number')))
df.show()
+     -+   +     +
|shouldMerge|number|sum_number|
+     -+   +     +
|       true|     4|         4|
|       true|     4|         4|
|       true|     4|         4|
|      false|     3|         4|
|      false|     1|         4|
+     -+   +     +

df = df.drop('sum_number')
my_window = Window.partitionBy().orderBy('shouldMerge')
df = df.withColumn('shouldMerge_lag', lag(col('shouldMerge'),1).over(my_window))
df.show()
+     -+   +       -+
|shouldMerge|number|shouldMerge_lag|
+     -+   +       -+
|      false|     3|           null|
|      false|     1|          false|
|       true|     4|          false|
|       true|     4|           true|
|       true|     4|           true|
+     -+   +       -+

df = df.where(~((col('shouldMerge')==True) & (col('shouldMerge_lag')==True))).drop('shouldMerge_lag')
df.show()
+     -+   +
|shouldMerge|number|
+     -+   +
|      false|     3|
|      false|     1|
|       true|     4|
+     -+   +

对于OP发布的第二个问题

# Create the DataFrame
valuesCol = [(1,2),(1,1),(2,1),(1,2),(-1,3),(-1,1)]
df = sqlContext.createDataFrame(valuesCol,['mergeId','number'])
df.show()
+   -+   +
|mergeId|number|
+   -+   +
|      1|     2|
|      1|     1|
|      2|     1|
|      1|     2|
|     -1|     3|
|     -1|     1|
+   -+   +

# Packages to be imported
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lag
# Register the dataframe as a view
df.registerTempTable('table_view')
df=sqlContext.sql(
    'select mergeId, number, sum(number) over (partition by mergeId) as sum_number from table_view'
)
df = df.withColumn('number',when(col('mergeId') > 0,col('sum_number')).otherwise(col('number')))
df.show()
+   -+   +     +
|mergeId|number|sum_number|
+   -+   +     +
|      1|     5|         5|
|      1|     5|         5|
|      1|     5|         5|
|      2|     1|         1|
|     -1|     3|         4|
|     -1|     1|         4|
+   -+   +     +

df = df.drop('sum_number')
my_window = Window.partitionBy('mergeId').orderBy('mergeId')
df = df.withColumn('mergeId_lag', lag(col('mergeId'),1).over(my_window))
df.show()
+   -+   +     -+
|mergeId|number|mergeId_lag|
+   -+   +     -+
|      1|     5|       null|
|      1|     5|          1|
|      1|     5|          1|
|      2|     1|       null|
|     -1|     3|       null|
|     -1|     1|         -1|
+   -+   +     -+

df = df.where(~((col('mergeId') > 0) & (col('mergeId_lag').isNotNull()))).drop('mergeId_lag')
df.show()
+   -+   +
|mergeId|number|
+   -+   +
|      1|     5|
|      2|     1|
|     -1|     3|
|     -1|     1|
+   -+   +

文档:lag()-返回当前行之前偏移行的值。你知道吗

相关问题 更多 >