为什么在NumPy中对FFT进行填充会使其运行变慢?

1 投票
2 回答
3164 浏览
提问于 2025-04-28 04:10

我写了一个脚本,使用了NumPy的fft函数。在这个脚本中,我把输入数组的大小调整到最接近的2的幂,这样可以让FFT运行得更快。

在分析代码性能时,我发现FFT的调用耗时最长。于是我试着调整了一下参数,发现如果我对输入数组进行填充,FFT的运行速度竟然快了好几倍。

这里有一个简单的例子来说明我的意思(我在IPython中运行这个,并使用%timeit这个命令来计时)。

x     = np.arange(-4.*np.pi, 4.*np.pi, 1000)
dat1  = np.sin(x)

计时结果:

%timeit np.fft.fft(dat1)
100000 loops, best of 3: 12.3 µs per loop

%timeit np.fft.fft(dat1, n=1024)
10000 loops, best of 3: 61.5 µs per loop

把数组填充到2的幂会导致速度大幅下降。

即使我创建一个元素个数是质数的数组(理论上这会让FFT运行得最慢),

x2    = np.arange(-4.*np.pi, 4.*np.pi, 1009)
dat2  = np.sin(x2)

运行所需的时间也没有变化得那么明显!

%timeit np.fft.fft(dat2)
100000 loops, best of 3: 12.2 µs per loop

我本以为填充数组是一次性的操作,之后计算FFT应该会更快。难道我漏掉了什么吗?

编辑:我应该使用np.linspace而不是np.arange。下面是使用linspace的计时结果。

In [2]: import numpy as np

In [3]: x = np.linspace(-4*np.pi, 4*np.pi, 1000)

In [4]: x2 = np.linspace(-4*np.pi, 4*np.pi, 1024)

In [5]: dat1 = np.sin(x)

In [6]: dat2 = np.sin(x2)

In [7]: %timeit np.fft.fft(dat1)
10000 loops, best of 3: 55.1 µs per loop

In [8]: %timeit np.fft.fft(dat2)
10000 loops, best of 3: 49.4 µs per loop

In [9]: %timeit np.fft.fft(dat1, n=1024)
10000 loops, best of 3: 64.9 µs per loop

填充数组仍然导致速度下降。这会不会是我本地的设置问题?也就是说,可能是因为我NumPy的配置有些奇怪,才会这样?

暂无标签

2 个回答

3

像NumPy这样的快速傅里叶变换(FFT)算法在处理数组时,如果数组的大小可以分解成几个小质数的乘积,速度会很快,而不仅仅是2的幂。如果你通过填充(增加元素)来增大数组的大小,计算的工作量也会增加。FFT算法的速度还和缓存的使用有很大关系。如果填充后的数组大小导致缓存使用效率降低,那么计算速度就会变慢。真正快速的FFT算法,比如FFTW和Intel MKL,会为数组大小的分解生成最佳计算方案,这个过程会结合一些经验法则和实际测量。因此,把数组填充到最近的2的幂在入门教材中可能有帮助,但在实际应用中不一定有效。一般来说,如果数组的大小可以分解成一个或多个很大的质数,填充通常会带来好处。

1

你在用 np.arange,其实应该用 np.linspace

In [2]: x     = np.arange(-4.*np.pi, 4.*np.pi, 1000)

In [3]: x
Out[3]: array([-12.56637061])

np.arange 需要三个参数(起始值,结束值,步长),而 np.linspace 则是(起始值,结束值,点的数量)。当你用你认为的数据进行计算时,结果会是你预期的那样:

In [4]: x = np.linspace(-4.*np.pi, 4.*np.pi, 1000)

In [5]: dat1 = np.sin(x)

In [6]: %timeit np.fft.fft(dat1)
1 loops, best of 3: 28.1 µs per loop

In [7]: %timeit np.fft.fft(dat1, n=1024)
10000 loops, best of 3: 26.7 µs per loop

In [8]: x = np.linspace(-4.*np.pi, 4.*np.pi, 1009)

In [9]: dat2 = np.sin(x)

In [10]: %timeit np.fft.fft(dat2)
10000 loops, best of 3: 53 µs per loop

In [11]: %timeit np.fft.fft(dat2, n=1024)
10000 loops, best of 3: 26.8 µs per loop

撰写回答