高效遍历3D数组?
我正在使用Python和Numpy进行一些数据分析。
我有一个很大的三维矩阵(NxNxN),其中每个单元格又是一个3x3的矩阵。我们把这个矩阵叫做data
,它的结构大概是这样的:
data[N,N,N,3,3]
我需要找到所有3x3矩阵的特征值,为此我使用Numpy的eigvals
函数,但这个过程非常慢。目前我大致上是这样做的:
for i in range(N):
for j in range(N):
for k in range(N):
a = np.linalg.eigvals(data[i,j,k,:,:])
当N=256时,这个过程大约需要一个小时。有没有什么办法可以让这个过程更高效呢?
非常感谢任何建议!
3 个回答
2
因为所有的计算都是独立的,如果你的电脑有多个核心的处理器,可以使用多进程模块来加快计算速度。
4
我相信在NumPy中有个很好的方法可以做到这一点,但一般来说,itertools.product
比起用嵌套循环遍历范围要快得多。
from itertools import product
for i, j, k in product(xrange(N), xrange(N), xrange(N)):
a = np.linalg.eigvals(data[i,j,k,:,:])
5
itertools.product
这个工具比起嵌套循环来说,看起来更好看。但我觉得它并不会让你的代码快很多。我的测试结果显示,迭代并不是你代码中最慢的部分。
>>> bigdata = numpy.arange(256 * 256 * 256 * 3 * 3).reshape(256, 256, 256, 3, 3)
>>> %timeit numpy.linalg.eigvals(bigdata[100, 100, 100, :, :])
10000 loops, best of 3: 52.6 us per loop
所以低估了:
>>> .000052 * 256 * 256 * 256 / 60
14.540253866666665
在我这台比较新的电脑上,至少需要14分钟。我们来看看循环需要多长时间……
>>> def just_loops(N):
... for i in xrange(N):
... for j in xrange(N):
... for k in xrange(N):
... pass
...
>>> %timeit just_loops(256)
1 loops, best of 3: 350 ms per loop
正如DSM所说,时间要小得多。就连单独切片数组的工作量也更大:
>>> def slice_loops(N, data):
... for i in xrange(N):
... for j in xrange(N):
... for k in xrange(N):
... data[i, j, k, :, :]
...
>>> %timeit slice_loops(256, bigdata)
1 loops, best of 3: 33.5 s per loop