如何使用 NumPy 的 strides 技巧重写"带重置的滚动求和"?
我有一段可以正常运行的代码。我想知道如何用 np.lib.stride_tricks.as_strided 来实现它,或者如何避免使用循环。
import yfinance as yf
import pandas as pd
import numpy as np
# Fetch Apple stock data
apple_data = yf.download('AAPL', start='2024-01-01', end='2024-03-31')
# Extract volume data
# NOTE(review): some yfinance versions return MultiIndex columns
# (field, ticker); in that case this selects a DataFrame rather than a
# Series. Confirm against the installed version.
apple_volume = apple_data['Volume']
# Insert every calendar day. Non-trading days (weekends, holidays) had no
# traded volume, so fill them with 0. Forward-filling would copy the previous
# session's volume onto those days and inflate every window sum.
apple_volume = apple_volume.asfreq('D', fill_value=0)
# Reference implementation: an explicit Python loop over the data.
def rolling_sum_with_reset(series, window_size):
    """Return the running total of *series*, restarting every *window_size* items.

    Position ``i`` of the result holds the sum of the elements from the start
    of its block (index ``i - i % window_size``) up to and including ``i``.
    The output is a float64 NumPy array with the same length as *series*.
    """
    totals = np.zeros(len(series))
    running = 0
    for pos, item in enumerate(series):
        # A position divisible by window_size opens a new block: drop the total.
        running = (0 if pos % window_size == 0 else running) + item
        totals[pos] = running
    return totals
# Running volume total that restarts every 3 rows.
rolling_3_day_volume = rolling_sum_with_reset(apple_volume, window_size=3)
3 个回答
1
我们可以这样做吗:
def rolling_sum_with_reset(series, window_size):
    """Return the cumulative sum of *series* that restarts every *window_size* items.

    Element ``i`` of the result is ``sum(series[i - i % window_size : i + 1])``.
    These are the same values the loop version in the question produces.

    The windows do not overlap, so no ``as_strided`` view is needed. After
    zero-padding the data to a whole number of windows, a plain ``reshape``
    to ``(n_windows, window_size)`` puts each window in its own row. A
    row-wise ``cumsum`` then does the work without a Python loop.

    Parameters
    ----------
    series : array-like
        One-dimensional numeric data (NumPy array, list or pandas Series).
    window_size : int
        Number of elements per window; must be at least 1.

    Returns
    -------
    numpy.ndarray
        float64 array with the same length as *series*.

    Raises
    ------
    ValueError
        If *window_size* is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be >= 1")
    values = np.asarray(series, dtype=float)
    n = values.shape[0]
    # Ceiling division: number of windows, counting a trailing partial one.
    n_windows = -(-n // window_size)
    # Zero-pad the tail so the last partial window fills a whole row. The
    # padding comes after the real data, so it never changes a kept result.
    padded = np.zeros(n_windows * window_size)
    padded[:n] = values
    sums = np.cumsum(padded.reshape(n_windows, window_size), axis=1)
    return sums.ravel()[:n]
3
我建议使用 numba 来加快计算速度:
import numba
@numba.njit
def rolling_sum_with_reset_numba(series, window_size):
    """Numba-compiled running total of *series* that restarts every
    *window_size* elements.

    Expects a NumPy array (callers pass ``.values``). The output is allocated
    with ``np.empty_like``, so it shares the input's shape and dtype.
    """
    result = np.empty_like(series)
    running = 0
    for idx in range(series.shape[0]):
        if idx % window_size == 0:
            running = 0  # first element of a new window
        running += series[idx]
        result[idx] = running
    return result
完整示例:
import numba
import numpy as np
import pandas as pd
from itertools import accumulate
# Build reproducible synthetic volume data (fixed seed) instead of downloading.
np.random.seed(42)
dr = pd.date_range("2024-01-01", "2024-03-3", freq="1D")
apple_data = pd.DataFrame(
    {
        "Date": dr,
        "Volume": np.random.randint(1000, 10_000, size=len(dr)),
    }
).set_index("Date")
# Pull the single volume column out as a Series.
apple_volume = apple_data.loc[:, "Volume"]
# Baseline: plain Python loop (the question's original implementation).
def rolling_sum_with_reset(series, window_size):
    """Cumulative sum of *series* that starts over every *window_size* elements.

    Returns a float64 NumPy array of ``len(series)`` where each entry is the
    sum from the beginning of its window through that entry.
    """
    out = np.zeros(len(series))
    acc = 0
    for idx, val in enumerate(series):
        if not idx % window_size:
            acc = 0  # window boundary: begin a fresh total
        acc = acc + val
        out[idx] = acc
    return out
@numba.njit
def rolling_sum_with_reset_numba(series, window_size):
    """JIT-compiled version of the reset running sum.

    Takes a NumPy array (the examples pass ``.values``) and returns an array
    created by ``np.empty_like(series)``, so the output dtype follows the
    input's.
    """
    totals = np.empty_like(series)
    acc = 0
    for k in range(len(series)):
        if not k % window_size:
            acc = 0  # window boundary
        acc += series[k]
        totals[k] = acc
    return totals
def rolling_sum_with_reset_accumulate(series, window_size):
    """Cumulative sum of *series* that restarts every *window_size* items.

    Pure-Python variant built on ``itertools.accumulate``; returns a list
    with one running total per input element.
    """
    steps = accumulate(
        enumerate(series),
        # item is (index, value). Keep the running total unless this index
        # opens a new window; the old code hard-coded 3 here and so ignored
        # window_size.
        lambda total, item: (total if item[0] % window_size > 0 else 0) + item[1],
        # Seed with 0 so the first (index, value) pair goes through the lambda.
        initial=0,
    )
    next(steps)  # discard the seed value that `initial` emits first
    return list(steps)
# Compute the reset running sum with both implementations and show them.
rolling_3_day_volume = rolling_sum_with_reset(apple_volume, window_size=3)
rolling_3_day_volume_nb = rolling_sum_with_reset_numba(
    apple_volume.to_numpy(), 3
)
print(rolling_3_day_volume)
print(rolling_3_day_volume_nb)
输出结果:
[ 8270. 10130. 16520. 6191. 12925. 20190. 1466. 6892. 13470. 9322.
12007. 13776. 7949. 11382. 17693. 6051. 13471. 15655. 5555. 9940.
17336. 9666. 13224. 22073. 3047. 6794. 7983. 3734. 7739. 13397.
2899. 11633. 13900. 2528. 7084. 11974. 9838. 16231. 26023. 9433.
17946. 21558. 8041. 15276. 21762. 8099. 9874. 19100. 4152. 6737.
11680. 8555. 12628. 14649. 4843. 13832. 21705. 6675. 7836. 13133.
1995. 10624. 12640.]
[ 8270 10130 16520 6191 12925 20190 1466 6892 13470 9322 12007 13776
7949 11382 17693 6051 13471 15655 5555 9940 17336 9666 13224 22073
3047 6794 7983 3734 7739 13397 2899 11633 13900 2528 7084 11974
9838 16231 26023 9433 17946 21558 8041 15276 21762 8099 9874 19100
4152 6737 11680 8555 12628 14649 4843 13832 21705 6675 7836 13133
1995 10624 12640]
使用 perfplot 进行基准测试:
import perfplot
def generate_df(n):
    """Build a synthetic daily-volume DataFrame for benchmarking.

    The index ("Date") runs day by day from ``30 * n`` days before
    2024-01-01 up to 2024-01-01 inclusive. The single "Volume" column holds
    random integers in [1000, 10000).
    """
    end = pd.Timestamp("2024-01-01")
    start = end - pd.Timedelta(days=30 * n)
    dates = pd.date_range(start, end, freq="1D")
    volumes = np.random.randint(1000, 10_000, size=len(dates))
    frame = pd.DataFrame({"Date": dates, "Volume": volumes})
    return frame.set_index("Date")
# Time the three implementations on 1 to 240 months of synthetic data and
# check that every kernel returns the same values (np.allclose).
perfplot.show(
    setup=generate_df,
    n_range=[1, 2, 6, 12, 24, 48, 120, 240],  # number of months
    kernels=[
        lambda df: rolling_sum_with_reset(df["Volume"], 3),
        lambda df: rolling_sum_with_reset_numba(df["Volume"].values, 3),
        lambda df: rolling_sum_with_reset_accumulate(df["Volume"], 3),
    ],
    labels=["original", "numba", "accumulate"],
    equality_check=np.allclose,
    xlabel="N",
    logx=True,
    logy=True,
)
在我的机器上(AMD 5700x/Ubuntu 20.04/Python 3.11)生成了这个图表:
pd.__version__='2.2.1'
np.__version__='1.26.4'
numba.__version__='0.59.0'
1
对于一个整数示例数组:
In [225]: arr = np.arange(90); x=rolling_sum_with_reset(arr,3)
In [226]: x
Out[226]:
array([ 0., 1., 3., 3., 7., 12., 6., 13., 21., 9., 19.,
30., 12., 25., 39., 15., 31., 48., 18., 37., 57., 21.,
43., 66., 24., 49., 75., 27., 55., 84., 30., 61., 93.,
33., 67., 102., 36., 73., 111., 39., 79., 120., 42., 85.,
129., 45., 91., 138., 48., 97., 147., 51., 103., 156., 54.,
109., 165., 57., 115., 174., 60., 121., 183., 63., 127., 192.,
66., 133., 201., 69., 139., 210., 72., 145., 219., 75., 151.,
228., 78., 157., 237., 81., 163., 246., 84., 169., 255., 87.,
175., 264.])
如果 arr 的长度是窗口大小的整数倍,就可以把它重塑(reshape)成一个二维数组,每行对应一个窗口,然后沿每一行做累加和(cumsum):
In [227]: y = np.cumsum(arr.reshape(-1,3), axis=1).ravel()
这样就能得到你想要的结果:
In [228]: np.allclose(x,y)
Out[228]: True
我觉得 cumsum 也可以配合某种标记"重置"位置的数组来使用,但这种重塑(reshape)的方法要简单得多。
as_strided 可以用来创建窗口,但因为这里的窗口之间不需要重叠,所以得到的结果和上面的 reshape 是一样的:
In [230]: np.lib.stride_tricks.sliding_window_view(arr,3)[::3]
Out[230]:
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17],
...
对于大小为 10 的窗口:
In [237]: rolling_sum_with_reset(arr,10)
Out[237]:
array([ 0., 1., 3., 6., 10., 15., 21., 28., 36., 45., 10.,
21., 33., 46., 60., 75., 91., 108., 126., 145., 20., 41.,
63., 86., 110., 135., 161., 188., 216., 245., 30., 61., 93.,
126., 160., 195., 231., 268., 306., 345., 40., 81., 123., 166.,
210., 255., 301., 348., 396., 445., 50., 101., 153., 206., 260.,
315., 371., 428., 486., 545., 60., 121., 183., 246., 310., 375.,
441., 508., 576., 645., 70., 141., 213., 286., 360., 435., 511.,
588., 666., 745., 80., 161., 243., 326., 410., 495., 581., 668.,
756., 845.])
In [238]: y = np.cumsum(arr.reshape(-1,10), axis=1).ravel()
In [239]: y
Out[239]:
array([ 0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 10, 21, 33,
46, 60, 75, 91, 108, 126, 145, 20, 41, 63, 86, 110, 135,
161, 188, 216, 245, 30, 61, 93, 126, 160, 195, 231, 268, 306,
345, 40, 81, 123, 166, 210, 255, 301, 348, 396, 445, 50, 101,
153, 206, 260, 315, 371, 428, 486, 545, 60, 121, 183, 246, 310,
375, 441, 508, 576, 645, 70, 141, 213, 286, 360, 435, 511, 588,
666, 745, 80, 161, 243, 326, 410, 495, 581, 668, 756, 845])