如何用 NumPy 的 strides(as_strided)重写带重置的滚动求和?

0 投票
3 回答
89 浏览
提问于 2025-04-12 04:19

下面这段代码可以正常运行。我想知道如何用 np.lib.stride_tricks.as_strided 实现同样的功能,或者至少避免使用 Python 循环。

import yfinance as yf
import pandas as pd
import numpy as np

# Fetch Apple stock data (network call to Yahoo Finance).
apple_data = yf.download('AAPL', start='2024-01-01', end='2024-03-31')

# Extract volume data.
apple_volume = apple_data['Volume']
# NOTE(review): newer yfinance releases return MultiIndex columns by default,
# in which case apple_data['Volume'] is a one-column DataFrame keyed by ticker.
# Reduce it to a Series so the rolling-sum functions iterate over numbers,
# not over column labels.
if isinstance(apple_volume, pd.DataFrame):
    apple_volume = apple_volume['AAPL']

# Ensure every calendar date is present. Volume is a per-day quantity, so
# non-trading days get 0 instead of a copy of the previous session's volume
# (forward-filling would count Friday's volume again on Saturday and Sunday).
apple_volume = apple_volume.asfreq('D', fill_value=0)

# Function to calculate rolling sum with reset using NumPy
def rolling_sum_with_reset(series, window_size):
    """Running total of *series* that restarts at every multiple of *window_size*.

    Returns a float64 NumPy array with one entry per item of *series*. Entry i is
    the sum of the items from index ``i - i % window_size`` up to and including i.
    A ``window_size`` of 0 raises ZeroDivisionError for a non-empty *series*.
    """
    totals = np.zeros(len(series))
    running = 0
    for pos, item in enumerate(series):
        # A zero remainder marks the first element of a window: start from 0.
        running = (running if pos % window_size else 0) + item
        totals[pos] = running
    return totals


# Three-day running volume, restarting every third calendar day.
rolling_3_day_volume = rolling_sum_with_reset(apple_volume, window_size=3)

3 个回答

1

我们可以这样做吗:

def rolling_sum_with_reset(series, window_size):
    """Return a cumulative sum of *series* that restarts every *window_size* items.

    Vectorised, loop-free equivalent of the loop version in the question:
    element ``i`` is the sum of ``series[k]`` for ``k`` from
    ``i - i % window_size`` through ``i``.

    The windows here do not overlap, and non-overlapping windows are exactly
    what ``reshape`` produces. No ``as_strided`` call is needed; a hand-built
    ``as_strided`` shape can also read past the end of the buffer. A trailing
    partial window is handled by zero-padding up to a multiple of
    ``window_size`` and trimming afterwards. Trailing zeros do not change the
    earlier partial sums.

    Parameters
    ----------
    series : array_like or pandas.Series
        One-dimensional sequence of numbers.
    window_size : int
        Number of elements per window; must be >= 1.

    Returns
    -------
    numpy.ndarray
        float64 array with the same length as *series*.

    Raises
    ------
    ValueError
        If *window_size* < 1 or *series* is not one-dimensional.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")
    # np.asarray accepts a pandas Series (which has no .strides) as well as lists.
    values = np.asarray(series, dtype=float)
    if values.ndim != 1:
        raise ValueError("series must be one-dimensional")
    n = values.shape[0]
    n_windows = -(-n // window_size)  # ceil(n / window_size)
    padded = np.zeros(n_windows * window_size)
    padded[:n] = values
    # One row per window; cumsum along each row restarts at every window boundary.
    return np.cumsum(padded.reshape(n_windows, window_size), axis=1).ravel()[:n]
3

我建议使用 numba 来加快计算速度:

import numba


@numba.njit
def rolling_sum_with_reset_numba(series, window_size):
    """JIT-compiled running total of *series* that restarts every *window_size* items.

    The callers in this example pass ``Series.values`` (a NumPy array) rather
    than a pandas Series. The result is a new array from ``np.empty_like``, so it
    has the same shape and dtype as *series*: integer input gives integer output.
    """
    out = np.empty_like(series)
    total = 0
    for idx in range(len(series)):
        if idx % window_size == 0:
            total = 0  # first element of a new window
        total += series[idx]
        out[idx] = total
    return out

完整示例:

import numba
import numpy as np
import pandas as pd
from itertools import accumulate

# Reproducible sample data standing in for the yfinance download:
# one random integer volume per calendar day from 2024-01-01 to 2024-03-03.
np.random.seed(42)
dr = pd.date_range("2024-01-01", "2024-03-3", freq="1D")
apple_data = pd.DataFrame(
    {
        "Date": dr,
        "Volume": np.random.randint(1000, 10_000, size=len(dr)),
    }
)
apple_data = apple_data.set_index("Date")

# Single-column Series consumed by the rolling-sum functions below.
apple_volume = apple_data["Volume"]


# Function to calculate rolling sum with reset using NumPy
def rolling_sum_with_reset(series, window_size):
    """Cumulative sum over consecutive, non-overlapping windows of *series*.

    Returns a float64 NumPy array the same length as *series*. The running
    total goes back to zero at every index that is a multiple of *window_size*.
    A ``window_size`` of 0 raises ZeroDivisionError for a non-empty *series*.
    """
    result = np.zeros(len(series))
    acc = 0
    for idx, item in enumerate(series):
        if not idx % window_size:  # first element of a new window
            acc = 0
        acc = acc + item
        result[idx] = acc
    return result


@numba.njit
def rolling_sum_with_reset_numba(series, window_size):
    """numba version of the reset cumulative sum, compiled on first call.

    Returns a new array from ``np.empty_like(series)`` (same shape and dtype as
    the input). Element i holds the sum of the items from the start of i's
    window through i; windows are ``window_size`` long.
    """
    n = len(series)
    result = np.empty_like(series)
    running = 0
    i = 0
    while i < n:
        if i % window_size == 0:
            running = 0  # window boundary: drop the carried total
        running += series[i]
        result[i] = running
        i += 1
    return result

def rolling_sum_with_reset_accumulate(series, window_size):
    """Running total of *series* that restarts every *window_size* items (pure Python).

    Returns a list with one entry per input item.

    Two differences from a boolean-multiplication reset:
    - The reset position comes from ``window_size`` instead of a hard-coded 3.
    - The carried total is replaced by 0 at each window start, rather than
      multiplied by a boolean. A NaN/inf total therefore does not leak into the
      next window (``nan * False`` is still ``nan``).

    Raises ZeroDivisionError if *window_size* is 0 and *series* is non-empty.
    """

    def step(total, indexed_item):
        idx, item = indexed_item
        # idx % window_size == 0 marks the first item of a window: drop the carry.
        return (total if idx % window_size else 0) + item

    running = accumulate(enumerate(series), step, initial=0)
    next(running)  # skip the seed value 0 emitted because of initial=0
    return list(running)

# Same data through both implementations; the numba one is given the raw ndarray.
rolling_3_day_volume = rolling_sum_with_reset(apple_volume, window_size=3)
rolling_3_day_volume_nb = rolling_sum_with_reset_numba(apple_volume.values, 3)

print(rolling_3_day_volume, rolling_3_day_volume_nb, sep="\n")

输出结果:

[ 8270. 10130. 16520.  6191. 12925. 20190.  1466.  6892. 13470.  9322.
 12007. 13776.  7949. 11382. 17693.  6051. 13471. 15655.  5555.  9940.
 17336.  9666. 13224. 22073.  3047.  6794.  7983.  3734.  7739. 13397.
  2899. 11633. 13900.  2528.  7084. 11974.  9838. 16231. 26023.  9433.
 17946. 21558.  8041. 15276. 21762.  8099.  9874. 19100.  4152.  6737.
 11680.  8555. 12628. 14649.  4843. 13832. 21705.  6675.  7836. 13133.
  1995. 10624. 12640.]

[ 8270 10130 16520  6191 12925 20190  1466  6892 13470  9322 12007 13776
  7949 11382 17693  6051 13471 15655  5555  9940 17336  9666 13224 22073
  3047  6794  7983  3734  7739 13397  2899 11633 13900  2528  7084 11974
  9838 16231 26023  9433 17946 21558  8041 15276 21762  8099  9874 19100
  4152  6737 11680  8555 12628 14649  4843 13832 21705  6675  7836 13133
  1995 10624 12640]

使用 perfplot 进行基准测试:

import perfplot


def generate_df(n):
    """Build a random daily "Volume" frame covering ``30 * n`` days before 2024-01-01.

    The index is named "Date" and runs daily up to and including 2024-01-01.
    Volumes are random integers drawn with ``np.random.randint(1000, 10_000)``.
    """
    end = pd.to_datetime("2024-01-01")
    dates = pd.date_range(end - pd.Timedelta(f"{30*n} days"), "2024-01-01", freq="1D")
    volumes = np.random.randint(1000, 10_000, size=len(dates))
    frame = pd.DataFrame({"Date": dates, "Volume": volumes})
    return frame.set_index("Date")


# Benchmark the three implementations on inputs of 1..240 "months"
# (blocks of 30 daily rows built by generate_df). Kernel outputs are
# compared with np.allclose via perfplot's equality_check.
perfplot.show(
    setup=generate_df,
    n_range=[1, 2, 6, 12, 24, 48, 120, 240],  # number of months
    kernels=[
        lambda df: rolling_sum_with_reset(df["Volume"], 3),
        lambda df: rolling_sum_with_reset_numba(df["Volume"].values, 3),
        lambda df: rolling_sum_with_reset_accumulate(df["Volume"], 3),
    ],
    labels=["original", "numba", "accumulate"],
    equality_check=np.allclose,
    xlabel="N",
    logx=True,
    logy=True,
)

在我的机器上(AMD 5700x/Ubuntu 20.04/Python 3.11)生成了这个图表:

pd.__version__='2.2.1'
np.__version__='1.26.4'
numba.__version__='0.59.0'

[图:perfplot 基准测试结果]

1
以一个整数示例数组为例:

In [225]: arr = np.arange(90); x=rolling_sum_with_reset(arr,3)
In [226]: x
Out[226]: 
array([  0.,   1.,   3.,   3.,   7.,  12.,   6.,  13.,  21.,   9.,  19.,
        30.,  12.,  25.,  39.,  15.,  31.,  48.,  18.,  37.,  57.,  21.,
        43.,  66.,  24.,  49.,  75.,  27.,  55.,  84.,  30.,  61.,  93.,
        33.,  67., 102.,  36.,  73., 111.,  39.,  79., 120.,  42.,  85.,
       129.,  45.,  91., 138.,  48.,  97., 147.,  51., 103., 156.,  54.,
       109., 165.,  57., 115., 174.,  60., 121., 183.,  63., 127., 192.,
        66., 133., 201.,  69., 139., 210.,  72., 145., 219.,  75., 151.,
       228.,  78., 157., 237.,  81., 163., 246.,  84., 169., 255.,  87.,
       175., 264.])

如果 arr 的长度是窗口大小的整数倍,就可以把它重塑(reshape)成二维数组,每行对应一个窗口,然后沿每一行做累加和(cumsum):

In [227]: y = np.cumsum(arr.reshape(-1,3), axis=1).ravel()

这样就能得到你想要的结果:

In [228]: np.allclose(x,y)
Out[228]: True

我觉得 cumsum 可以和某种“重置”步骤数组一起使用,但这种重塑(reshape)的方法要简单得多。

as_strided 可以用来创建窗口,但因为你不需要窗口之间有重叠,所以结果和这个 reshape 是一样的:

In [230]: np.lib.stride_tricks.sliding_window_view(arr,3)[::3]
Out[230]: 
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17],
        ...

对于大小为 10 的窗口:

In [237]: rolling_sum_with_reset(arr,10)
Out[237]: 
array([  0.,   1.,   3.,   6.,  10.,  15.,  21.,  28.,  36.,  45.,  10.,
        21.,  33.,  46.,  60.,  75.,  91., 108., 126., 145.,  20.,  41.,
        63.,  86., 110., 135., 161., 188., 216., 245.,  30.,  61.,  93.,
       126., 160., 195., 231., 268., 306., 345.,  40.,  81., 123., 166.,
       210., 255., 301., 348., 396., 445.,  50., 101., 153., 206., 260.,
       315., 371., 428., 486., 545.,  60., 121., 183., 246., 310., 375.,
       441., 508., 576., 645.,  70., 141., 213., 286., 360., 435., 511.,
       588., 666., 745.,  80., 161., 243., 326., 410., 495., 581., 668.,
       756., 845.])

In [238]: y = np.cumsum(arr.reshape(-1,10), axis=1).ravel()

In [239]: y
Out[239]: 
array([  0,   1,   3,   6,  10,  15,  21,  28,  36,  45,  10,  21,  33,
        46,  60,  75,  91, 108, 126, 145,  20,  41,  63,  86, 110, 135,
       161, 188, 216, 245,  30,  61,  93, 126, 160, 195, 231, 268, 306,
       345,  40,  81, 123, 166, 210, 255, 301, 348, 396, 445,  50, 101,
       153, 206, 260, 315, 371, 428, 486, 545,  60, 121, 183, 246, 310,
       375, 441, 508, 576, 645,  70, 141, 213, 286, 360, 435, 511, 588,
       666, 745,  80, 161, 243, 326, 410, 495, 581, 668, 756, 845])

撰写回答