在多维numpy数组中迭代向量

2 投票
2 回答
2150 浏览
提问于 2025-04-17 16:42

我有一个3xNxM的numpy数组a,我想要遍历最后两个维度,也就是a[:,x,y]。一种不太优雅的方法是:

import numpy as np
a = np.arange(60).reshape((3,4,5))
M = np. array([[1,0,0],
               [0,0,0],
               [0,0,-1]])

for x in arange(a.shape[1]):
    for y in arange(a.shape[2]):
        a[:,x,y] = M.dot(a[:,x,y])

我可以用nditer来实现这个吗?我的目标是对每个元素进行矩阵乘法,比如说a[:,x,y] = M[:,:,x,y].dot(a[:,x,y])。还有一种类似MATLAB的做法是把a重塑为(3,N*M),把M重塑为(3,3*N*M),然后进行点乘,但这样会占用很多内存。

2 个回答

2
for x in np.arange(a.shape[1]):
    for y in np.arange(a.shape[2]):
        a[:,x,y] = M.dot(a[:,x,y])

等价于

a = np.dot(M,a.swapaxes(0,1))

In [73]: np.dot(M,a.swapaxes(0,1))
Out[73]: 
array([[[  0,   1,   2,   3,   4],
        [  5,   6,   7,   8,   9],
        [ 10,  11,  12,  13,  14],
        [ 15,  16,  17,  18,  19]],

       [[  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0]],

       [[-40, -41, -42, -43, -44],
        [-45, -46, -47, -48, -49],
        [-50, -51, -52, -53, -54],
        [-55, -56, -57, -58, -59]]])

解释:

对于多维数组,np.dot(M,a) 是在M的最后一个维度和a的倒数第二个维度上进行求和乘积。

a的形状是 (3,4,5),但我们想要在形状为 3 的维度上进行求和。因为我们要对倒数第二个维度进行求和,所以需要用 a.swapaxis(0,1),它的形状变成 (4,3,5),这样就把 3 移到了倒数第二个维度。

M的形状是 (3,3),而 a.swapaxis(0,1) 的形状是 (4,3,5)。去掉 M 的最后一个维度和 a.swapaxis(0,1) 的倒数第二个维度后,剩下的形状是 (3,) 和 (4,5),所以 np.dot 返回的结果是形状为 (3,4,5) 的数组——正是我们想要的结果。

5

在玩弄形状的时候,可能会让你想要实现的目标变得更加清晰,但处理这类问题最简单的方法是使用 np.einsum,这样你就不用想太多了:

In [5]: np.einsum('ij, jkl', M, a)
Out[5]: 
array([[[  0,   1,   2,   3,   4],
        [  5,   6,   7,   8,   9],
        [ 10,  11,  12,  13,  14],
        [ 15,  16,  17,  18,  19]],

       [[  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0]],

       [[-40, -41, -42, -43, -44],
        [-45, -46, -47, -48, -49],
        [-50, -51, -52, -53, -54],
        [-55, -56, -57, -58, -59]]])

而且它通常还能提高性能:

In [17]: a = np.random.randint(256, size=(3, 1000, 2000))

In [18]: %timeit np.dot(M, a.swapaxes(0,1))
10 loops, best of 3: 116 ms per loop

In [19]: %timeit np.einsum('ij, jkl', M, a)
10 loops, best of 3: 60.7 ms per loop

编辑 einsum 是个非常强大的工具。你也可以像下面评论中的提问者那样来实现:

>>> a = np.arange(60).reshape((3,4,5))
>>> M = np.array([[1,0,0], [0,0,0], [0,0,-1]])
>>> M = M.reshape((3,3,1,1)).repeat(4,axis=2).repeat(5,axis=3)
>>> np.einsum('ijkl,jkl->ikl', M, b)
array([[[  0,   1,   2,   3,   4],
        [  5,   6,   7,   8,   9],
        [ 10,  11,  12,  13,  14],
        [ 15,  16,  17,  18,  19]],

       [[  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0]],

       [[-40, -41, -42, -43, -44],
        [-45, -46, -47, -48, -49],
        [-50, -51, -52, -53, -54],
        [-55, -56, -57, -58, -59]]])

撰写回答