Spark: merge multiple rows into a single row based on specific columns, without a groupBy operation

Published 2024-06-16 11:08:46


I have a Spark DataFrame with 7k columns, as shown below:

+---+----+----+----+----+----+----+
| id|   1|   2|   3|sf_1|sf_2|sf_3|
+---+----+----+----+----+----+----+
|  2|null|null|null| 102| 202| 302|
|  4|null|null|null| 104| 204| 304|
|  1|null|null|null| 101| 201| 301|
|  3|null|null|null| 103| 203| 303|
|  1|  11|  21|  31|null|null|null|
|  2|  12|  22|  32|null|null|null|
|  4|  14|  24|  34|null|null|null|
|  3|  13|  23|  33|null|null|null|
+---+----+----+----+----+----+----+

I want to transform the DataFrame by merging the complementary null rows, as shown below. I can collapse them into a single row per id with a groupBy, but the performance of that aggregation is very poor because the table has 7k columns:

import pyspark.sql.functions as F

(df.groupBy('id').agg(*[F.first(x,ignorenulls=True) for x in df.columns if x!='id'])
.show())
+---+----+----+----+----+----+----+
| id|   1|   2|   3|sf_1|sf_2|sf_3|
+---+----+----+----+----+----+----+
|  1|  11|  21|  31| 101| 201| 301|
|  2|  12|  22|  32| 102| 202| 302|
|  4|  14|  24|  34| 104| 204| 304|
|  3|  13|  23|  33| 103| 203| 303|
+---+----+----+----+----+----+----+

Any other suggestions, optimizations, or more efficient approaches would be appreciated. Thanks.

Update 1: after trying the self-join approach, count() fails with a StackOverflowError while Catalyst transforms the generated expression tree, which becomes very deep with one CASE WHEN per column (~7k of them):

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-17-b7de100341cc> in <module>
     15 """.format(table_name, query, join_key)
     16 
---> 17 spark.sql(final_query).dropDuplicates().filter(filters).count()

~/quartic/spark-3.0.0-bin-hadoop2.7/python/pyspark/sql/dataframe.py in count(self)
    583         2
    584         """
--> 585         return int(self._jdf.count())
    586 
    587     @ignore_unicode_prefix

~/quartic/spark-3.0.0-bin-hadoop2.7/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306 
   1307         for temp_arg in temp_args:

~/quartic/spark-3.0.0-bin-hadoop2.7/python/pyspark/sql/utils.py in deco(*a, **kw)
    129     def deco(*a, **kw):
    130         try:
--> 131             return f(*a, **kw)
    132         except py4j.protocol.Py4JJavaError as e:
    133             converted = convert_exception(e.java_exception)

~/quartic/spark-3.0.0-bin-hadoop2.7/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o148.count.
: java.lang.StackOverflowError
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:35)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
    at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
    at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
    at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:184)
    at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:47)
    at scala.collection.generic.GenericCompanion.apply(GenericCompanion.scala:53)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.children(Expression.scala:533)
    at org.apache.spark.sql.catalyst.trees.TreeNode.containsChild$lzycompute(TreeNode.scala:115)
    at org.apache.spark.sql.catalyst.trees.TreeNode.containsChild(TreeNode.scala:115)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:349)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)

2 answers

You can use a self-join as shown below:

from pyspark.sql.types import IntegerType, StructField, StructType

values_arr = [
(2,None, None,None,102, 202, 302),
(4,None, None,None,104, 204, 304),
(1,None, None,None,101, 201, 301),
(3,None, None,None,103, 203, 303),
(1,11, 21,31,None,None,None),
(2,12, 22,32,None,None,None),
(4,14, 24,34,None,None,None),
(3,13, 23,33,None,None,None)
]

sc = spark.sparkContext
rdd = sc.parallelize(values_arr)
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("col_1", IntegerType(), True),
    StructField("col_2", IntegerType(), True),
    StructField("col_3", IntegerType(), True),
    StructField("sf_1", IntegerType(), True),
    StructField("sf_2", IntegerType(), True),
    StructField("sf_3", IntegerType(), True)
])

df = spark.createDataFrame(rdd, schema)
df.show()

# Sample input

+---+-----+-----+-----+----+----+----+
| id|col_1|col_2|col_3|sf_1|sf_2|sf_3|
+---+-----+-----+-----+----+----+----+
|  2| null| null| null| 102| 202| 302|
|  4| null| null| null| 104| 204| 304|
|  1| null| null| null| 101| 201| 301|
|  3| null| null| null| 103| 203| 303|
|  1|   11|   21|   31|null|null|null|
|  2|   12|   22|   32|null|null|null|
|  4|   14|   24|   34|null|null|null|
|  3|   13|   23|   33|null|null|null|
+---+-----+-----+-----+----+----+----+

# Solution
df.createTempView("my_table")
query="select l.id as id,r.col_1 as col_1, r.col_2 as col_2, r.col_3 as col_3, l.sf_1 as sf_1, l.sf_2 as sf_2,l.sf_3 as sf_3 from my_table l, my_table r where l.id=r.id and r.col_1 is not null and l.sf_1 is not null"

spark.sql(query).show()

# Sample output

+---+-----+-----+-----+----+----+----+
| id|col_1|col_2|col_3|sf_1|sf_2|sf_3|
+---+-----+-----+-----+----+----+----+
|  1|   11|   21|   31| 101| 201| 301|
|  3|   13|   23|   33| 103| 203| 303|
|  4|   14|   24|   34| 104| 204| 304|
|  2|   12|   22|   32| 102| 202| 302|
+---+-----+-----+-----+----+----+----+
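
The query above spells out every column by hand, which is not practical for 7k columns. Here is a minimal sketch of generating the same select list from df.columns instead, assuming (as in this example) that the value columns share the col_ prefix and the staging columns the sf_ prefix; value_cols, sf_cols and generated_query are names introduced purely for illustration:

# Build the select list programmatically instead of hand-writing 7k columns
value_cols = [c for c in df.columns if c.startswith("col_")]
sf_cols = [c for c in df.columns if c.startswith("sf_")]

select_list = ", ".join(
    ["l.id as id"]
    + ["r.{0} as {0}".format(c) for c in value_cols]
    + ["l.{0} as {0}".format(c) for c in sf_cols]
)

generated_query = (
    "select " + select_list + " from my_table l, my_table r "
    "where l.id = r.id and r.col_1 is not null and l.sf_1 is not null"
)
spark.sql(generated_query).show()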

You can try this solution. Let me know if it runs faster.

from pyspark.sql.types import IntegerType, StructField, StructType

values = [
(2,None, None,None,102, 202, 302),
(4,None, None,None,104, 204, 304),
(1,None, None,None,101, 201, 301),
(3,None, None,None,103, 203, 303),
(1,11, 21,31,None,None,None),
(2,12, 22,32,None,None,None),
(4,14, 24,34,None,None,None),
(3,13, 23,33,None,None,None)
]

sc = spark.sparkContext
rdd = sc.parallelize(values)
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("col1", IntegerType(), True),
    StructField("col2", IntegerType(), True),
    StructField("col3", IntegerType(), True),
    StructField("sf_1", IntegerType(), True),
    StructField("sf_2", IntegerType(), True),
    StructField("sf_3", IntegerType(), True)
])

data = spark.createDataFrame(rdd, schema)
data.show()
# +---+----+----+----+----+----+----+
# | id|col1|col2|col3|sf_1|sf_2|sf_3|
# +---+----+----+----+----+----+----+
# |  2|null|null|null| 102| 202| 302|
# |  4|null|null|null| 104| 204| 304|
# |  1|null|null|null| 101| 201| 301|
# |  3|null|null|null| 103| 203| 303|
# |  1|  11|  21|  31|null|null|null|
# |  2|  12|  22|  32|null|null|null|
# |  4|  14|  24|  34|null|null|null|
# |  3|  13|  23|  33|null|null|null|
# +---+----+----+----+----+----+----+

data.createOrReplaceTempView("data")
join_key = 'id'
table_name = 'data'
query = "{0}".format(join_key)
filters = ""
for index, column_name in enumerate(data.columns):
    if join_key != column_name:
        query += ",\n\t case when a." + column_name + " is null then b." + column_name + " else a." + column_name + " end as " + column_name 
        filters += "\nAND {0} IS NOT NULL".format(column_name) if index !=1 else " {0} IS NOT NULL".format(column_name) 
final_query ="""
SELECT a.{1}
FROM {0} a INNER JOIN {0} b ON a.{2} = b.{2}
""".format(table_name, query, join_key)
print(final_query)
# SELECT a.id,
#    case when a.col1 is null then b.col1 else a.col1 end as col1,
#    case when a.col2 is null then b.col2 else a.col2 end as col2,
#    case when a.col3 is null then b.col3 else a.col3 end as col3,
#    case when a.sf_1 is null then b.sf_1 else a.sf_1 end as sf_1,
#    case when a.sf_2 is null then b.sf_2 else a.sf_2 end as sf_2,
#    case when a.sf_3 is null then b.sf_3 else a.sf_3 end as sf_3
# FROM data a INNER JOIN data b ON a.id = b.id

print(filters)
#  col1 IS NOT NULL
# AND col2 IS NOT NULL
# AND col3 IS NOT NULL
# AND sf_1 IS NOT NULL
# AND sf_2 IS NOT NULL
# AND sf_3 IS NOT NULL

spark.sql(final_query).dropDuplicates().filter(filters).show()
# +---+----+----+----+----+----+----+
# | id|col1|col2|col3|sf_1|sf_2|sf_3|
# +---+----+----+----+----+----+----+
# |  1|  11|  21|  31| 101| 201| 301|
# |  3|  13|  23|  33| 103| 203| 303|
# |  4|  14|  24|  34| 104| 204| 304|
# |  2|  12|  22|  32| 102| 202| 302|
# +---+----+----+----+----+----+----+
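
Each generated "case when a.x is null then b.x else a.x end" is equivalent to coalesce(a.x, b.x), so the same self-join can also be expressed with the DataFrame API instead of assembling a SQL string. A minimal sketch under that assumption, reusing the data DataFrame and the id join key (merged is just an illustrative name):

import pyspark.sql.functions as F

a = data.alias("a")
b = data.alias("b")

merged = (
    a.join(b, on="id")  # self-join on the key column
     .select(
         "id",
         *[F.coalesce(F.col("a." + c), F.col("b." + c)).alias(c)
           for c in data.columns if c != "id"]
     )
     .dropDuplicates()
     .na.drop()  # keep only fully populated rows, like the IS NOT NULL filters
)
merged.show()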
