确保PySpark阵列中相邻元素之间的差异大于给定的最小值

2024-04-25 11:33:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个PySpark数据帧(df),有三列。你知道吗

1。 category:一些字符串

2。 startTimeArray:它是一个数组,包含按升序排列的时间戳。你知道吗

三。 endTimeArray:它是一个数组,包含按升序排列的时间戳。你知道吗

在每一行中,startTimeArray中的数组长度与endTimeArray中的数组长度相同。对于这些数组中的每个索引,startTimeArray中给出的时间戳小于endTimeArray中相应的(相同索引)时间戳(发生在前一个日期)。你知道吗

在列startTimeArray(和列endTimeArray)中,数组的长度可以不同。你知道吗

以下是数据帧的示例:

+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|category|startTimeArray                                                                                           |endTimeArray                                                                                             |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|a       |[2019-01-10 00:00:00, 2019-01-12 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00]                     |[2019-01-11 00:00:00, 2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00]                     |
|a       |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-20 00:00:00, 2019-03-25 00:00:00, 2019-03-27 00:00:00]|[2019-03-16 00:00:00, 2019-03-19 00:00:00, 2019-03-23 00:00:00, 2019-03-26 00:00:00, 2019-03-30 00:00:00]|
|b       |[2019-01-14 00:00:00, 2019-01-16 00:00:00, 2019-02-22 00:00:00]                                          |[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-02-25 00:00:00]                                          |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+

在每一行的startTimeArray列中,我要确保数组中连续元素(连续索引处的元素)之间的差异至少为三天。如果startTimeArray中的一行有n元素,我可以删除数组中的条目,除了第一个条目。此外,如果索引i处的元素从startTimeArray中的行中删除,我希望索引i-1处的元素从endTimeArray中的同一行中删除。**

如何使用PySpark完成此任务?你知道吗

我们需要注意的是:

  1. 如果startTimeArray中的数组有一个元素,我们就让它在那里。

  2. 我意识到这个任务可以通过删除startTimeArray中数组中第一个元素之后的所有元素来实现。那将是一个微不足道的例子。但是我想通过尽可能少的删除来完成任务。

下面是我在上面给出的示例dataframe df中想要的输出。你知道吗

+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|category|startTimeArray                                                 |endTimeArray                                                   |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|a       |[2019-01-10 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00]|[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00]|
|a       |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-25 00:00:00]|[2019-03-16 00:00:00, 2019-03-23 00:00:00, 2019-03-30 00:00:00]|
|b       |[2019-01-14 00:00:00, 2019-02-22 00:00:00]                     |[2019-01-18 00:00:00, 2019-02-25 00:00:00]                     |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+

Tags: 数据字符串元素示例df时间条目数组
1条回答
网友
1楼 · 发布于 2024-04-25 11:33:59

用户定义函数(UDF)可以完成这项工作。虽然与本机sparksql函数相比,它会带来性能损失,但它清楚地表达了所需的操作。你知道吗

from datetime import date, timedelta

from pyspark.sql.functions import *
from pyspark.sql.types import *

d = [date(2019, 1, d) for d in (10, 12, 16, 20)]
e = [date(2019, 1, d) for d in (11, 15, 18, 22)]
f = [date(2019, 3, d) for d in (11, 18, 20, 25, 27)]
g = [date(2019, 3, d) for d in (16, 19, 23, 26, 30)]
h = [date(2019, 1, 14), date(2019, 1, 16), date(2019, 2, 22)]
i = [date(2019, 1, 15), date(2019, 1, 18), date(2019, 2, 25)]

df = spark.createDataFrame((("a", d, e), ("a", f, g), ("b", h, i)),
                           schema=("category", "startDates", "endDates"))


@udf(returnType=ArrayType(ArrayType(DateType())))
def retain_dates_n_days_apart(startDates, endDates, min_apart=3):
    start_dates = [startDates[0]]
    end_dates = []
    for start, end in zip(startDates[1:], endDates):
        if start >= start_dates[-1] + timedelta(days=min_apart):
            start_dates.append(start)
            end_dates.append(end)
    end_dates.append(endDates[-1])
    return start_dates, end_dates


df2 = (df
       .withColumn("foo",
                   retain_dates_n_days_apart(df.startDates,
                                             df.endDates))
       .cache())

(df2.withColumn("startDates", df2.foo.getItem(0))
 .withColumn("endDates", df2.foo.getItem(1))
 .drop("foo")
 ).show(truncate=False)
# +    +                  +                  +
# |category|startDates                          |endDates                            |
# +    +                  +                  +
# |a       |[2019-01-10, 2019-01-16, 2019-01-20]|[2019-01-15, 2019-01-18, 2019-01-22]|
# |a       |[2019-03-11, 2019-03-18, 2019-03-25]|[2019-03-16, 2019-03-23, 2019-03-30]|
# |b       |[2019-01-14, 2019-02-22]            |[2019-01-18, 2019-02-25]            |
# +    +                  +                  +

相关问题 更多 >