高效遍历3D数组?

5 投票
3 回答
5723 浏览
提问于 2025-04-16 23:04

我正在使用Python和Numpy进行一些数据分析。

我有一个很大的三维矩阵(NxNxN),其中每个单元格又是一个3x3的矩阵。我们把这个矩阵叫做data,它的结构大概是这样的:

data[N,N,N,3,3]  

我需要找到所有3x3矩阵的特征值,为此我使用Numpy的eigvals函数,但这个过程非常慢。目前我大致上是这样做的:

for i in range(N):
    for j in range(N):
        for k in range(N):
            a = np.linalg.eigvals(data[i,j,k,:,:])

当N=256时,这个过程大约需要一个小时。有没有什么办法可以让这个过程更高效呢?

非常感谢任何建议!

3 个回答

2

因为所有的计算都是独立的,如果你的电脑有多个核心的处理器,可以使用多进程模块来加快计算速度。

4

我相信在NumPy中有个很好的方法可以做到这一点,但一般来说,itertools.product 比起用嵌套循环遍历范围要快得多。

from itertools import product

for i, j, k in product(xrange(N), xrange(N), xrange(N)):
    a = np.linalg.eigvals(data[i,j,k,:,:])
5

itertools.product 这个工具比起嵌套循环来说,看起来更好看。但我觉得它并不会让你的代码快很多。我的测试结果显示,迭代并不是你代码中最慢的部分。

>>> bigdata = numpy.arange(256 * 256 * 256 * 3 * 3).reshape(256, 256, 256, 3, 3)
>>> %timeit numpy.linalg.eigvals(bigdata[100, 100, 100, :, :])
10000 loops, best of 3: 52.6 us per loop

所以低估了:

>>> .000052 * 256 * 256 * 256 / 60
14.540253866666665

在我这台比较新的电脑上,至少需要14分钟。我们来看看循环需要多长时间……

>>> def just_loops(N):
...     for i in xrange(N):
...         for j in xrange(N):
...             for k in xrange(N):
...                 pass
... 
>>> %timeit just_loops(256)
1 loops, best of 3: 350 ms per loop

正如DSM所说,时间要小得多。就连单独切片数组的工作量也更大:

>>> def slice_loops(N, data):
...     for i in xrange(N):
...         for j in xrange(N):
...             for k in xrange(N):
...                 data[i, j, k, :, :]
... 
>>> %timeit slice_loops(256, bigdata)
1 loops, best of 3: 33.5 s per loop

撰写回答