Summing values over overlapping interval slices for each group

Posted 2024-04-26 04:09:57


I have a PySpark DataFrame that looks like this:

import pandas as pd
from pyspark.sql import SparkSession

spark = (SparkSession.builder 
         .master("local") 
         .getOrCreate())
spark.conf.set("spark.sql.session.timeZone", "UTC")

INPUT = {
    "idx": [1, 1, 1, 1, 0],
    "consumption": [10.0, 20.0, 30.0, 40.0, 5.0],
    "valid_from": [
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-03 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-06 00:00:00+00:00", tz="UTC"),
        pd.Timestamp("2019-01-01 00:00:00+00:00", tz="UTC"),
    ],
    "valid_to": [
        pd.Timestamp("2019-01-02 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-05 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-08 00:00:00+0000", tz="UTC"),
        pd.Timestamp("2019-01-02 00:00:00+00:00", tz="UTC"),
    ],
}
df = pd.DataFrame.from_dict(INPUT)
spark.createDataFrame(df).show()

>>>
   +---+-----------+-------------------+-------------------+
   |idx|consumption|         valid_from|           valid_to|
   +---+-----------+-------------------+-------------------+
   |  1|       10.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
   |  1|       20.0|2019-01-02 00:00:00|2019-01-05 00:00:00|
   |  1|       30.0|2019-01-03 00:00:00|2019-01-05 00:00:00|
   |  1|       40.0|2019-01-06 00:00:00|2019-01-08 00:00:00|
   |  0|        5.0|2019-01-01 00:00:00|2019-01-02 00:00:00|
   +---+-----------+-------------------+-------------------+

I just want to sum the consumption over the overlapping interval slices for each idx, producing the table below (a plain-pandas sketch that reproduces these numbers follows the table):

+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
+---+-------------------+-----------+
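
For reference, the expected numbers can be reproduced with plain pandas (a minimal sketch; the inclusive argument of date_range assumes pandas >= 1.4):

# Pandas-only check of the expected table: expand every interval into the days
# it covers, add the interval end with consumption 0.0, then sum per (idx, day).
rows = []
for _, r in df.iterrows():
    covered = pd.date_range(r["valid_from"], r["valid_to"], freq="D", inclusive="left")
    rows += [{"idx": r["idx"], "timestamp": d, "consumption": r["consumption"]} for d in covered]
    # the interval end only contributes 0.0 unless another interval covers it
    rows.append({"idx": r["idx"], "timestamp": r["valid_to"], "consumption": 0.0})

expected = (
    pd.DataFrame(rows)
    .groupby(["idx", "timestamp"], as_index=False)["consumption"].sum()
    .sort_values(["idx", "timestamp"])
)
print(expected)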

1 Answer

#1 · Posted 2024-04-26 04:09:57

You can use sequence to expand each interval into individual days, explode the list of days, and then sum the consumption for every idx and timestamp:

from pyspark.sql import functions as F

input = spark.createDataFrame(df)
(
    input
    # expand each interval into the days it covers; sequence is inclusive of
    # its end value, so one day is subtracted from valid_to first
    .withColumn("all_days", F.sequence("valid_from", F.date_sub("valid_to", 1)))
    # one row per (idx, day)
    .withColumn("timestamp", F.explode("all_days"))
    # sum the consumption of all intervals covering the same day
    .groupBy("idx", "timestamp").sum("consumption")
    .withColumnRenamed("sum(consumption)", "consumption")
    # full outer join against the original valid_to values to bring back the
    # interval end dates dropped by the date_sub above
    .join(
        input.select("idx", "valid_to").distinct().withColumnRenamed("idx", "idx2"),
        (F.col("timestamp") == F.col("valid_to")) & (F.col("idx") == F.col("idx2")),
        "full_outer",
    )
    .withColumn("idx", F.coalesce("idx", "idx2"))
    .withColumn("timestamp", F.coalesce("timestamp", "valid_to"))
    .drop("idx2", "valid_to")
    # the restored end dates have no consumption yet; fill the nulls with 0.0
    .fillna(0.0)
    .orderBy("idx", "timestamp")
    .show()
)

Output:


+---+-------------------+-----------+
|idx|          timestamp|consumption|
+---+-------------------+-----------+
|  0|2019-01-01 00:00:00|        5.0|
|  0|2019-01-02 00:00:00|        0.0|
|  1|2019-01-01 00:00:00|       10.0|
|  1|2019-01-02 00:00:00|       20.0|
|  1|2019-01-03 00:00:00|       50.0|
|  1|2019-01-04 00:00:00|       50.0|
|  1|2019-01-05 00:00:00|        0.0|
|  1|2019-01-06 00:00:00|       40.0|
|  1|2019-01-07 00:00:00|       40.0|
|  1|2019-01-08 00:00:00|        0.0|
+---+-------------------+-----------+

Notes:

  • sequence includes the last value of the interval, so one day has to be subtracted from valid_to (see the sketch after these notes)
  • The dropped interval end dates are then restored with a full join against the original valid_to values, filling the resulting null values with 0.0
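
A minimal sketch of the first point, with hard-coded dates just for illustration:

from pyspark.sql import functions as F

# sequence(start, stop) is inclusive on both ends, so 2019-01-02..2019-01-05
# yields four days; subtracting one day from the stop drops the interval end.
spark.range(1).select(
    F.sequence(F.to_date(F.lit("2019-01-02")),
               F.to_date(F.lit("2019-01-05"))).alias("inclusive"),
    F.sequence(F.to_date(F.lit("2019-01-02")),
               F.date_sub(F.to_date(F.lit("2019-01-05")), 1)).alias("end_removed"),
).show(truncate=False)

The first column still contains 2019-01-05 while the second stops at 2019-01-04, which is why date_sub("valid_to", 1) is applied before exploding.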
