从Numpy ndarray中提取指定行列的最快方法是什么?

6 投票
1 回答
1194 浏览
提问于 2025-04-18 21:29

我有一个很大的方阵(大约14,000 x 14,000),它是用Numpy的ndarray表示的。我想提取很多行和列——这些行列的索引我已经提前知道了,实际上就是所有不全是零的行和列——以得到一个新的方阵(大约10,000 x 10,000)。

我找到的最快的方法是:

> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 6.19 s per loop

不过,这个方法的速度比进行矩阵乘法要慢很多:

> timeit np.multiply(A, A)
1 loops, best of 3: 982 ms per loop

这听起来有点奇怪,因为提取行和列以及矩阵乘法都需要分配一个新的数组(而且矩阵乘法的结果会比提取的结果还要大),但矩阵乘法还需要进行额外的计算。

所以,我想问:有没有更有效的方法来进行提取,特别是能和矩阵乘法一样快的方法?

1 个回答

1

如果我尝试复现你遇到的问题,我并没有看到那么明显的效果。我注意到,根据你选择的索引数量,索引的速度甚至可能比乘法还要快。

>>> import numpy as np
>>> np.__version__
Out[1]: '1.9.0'
>>> N = 14000
>>> A = np.random.random(size=[N, N])

>>> indices = np.sort(np.random.choice(np.arange(N), 0.9*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 1.02 s per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 1.37 s per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 748 ms per loop

>>> indices = np.sort(np.random.choice(np.arange(N), 0.7*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 633 ms per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 946 ms per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 728 ms per loop

撰写回答