在numpy中,如何最快地将三维数组的第二维与一维数组相乘?

4 投票
1 回答
2547 浏览
提问于 2025-04-16 23:48

你有一个形状为 (a,b,c) 的数组,想要把第二个维度(也就是 b 这一部分)乘以一个形状为 (b) 的数组。

用一个 for 循环是可以做到的,但有没有更好的方法呢?

比如:

A = np.array(shape=(a,b,c))
B = np.array(shape=(b))

for i in B.shape[0]:
    A[:,i,:]=A[:,i,:]*B[i]

1 个回答

6

使用广播的概念:

A*B[:,np.newaxis]

举个例子:

In [47]: A=np.arange(24).reshape(2,3,4)

In [48]: B=np.arange(3)

In [49]: A*B[:,np.newaxis]
Out[49]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  0,  0,  0],
        [16, 17, 18, 19],
        [40, 42, 44, 46]]])

B[:,np.newaxis]的形状是(3,1)。广播的意思是在左边添加新的维度,所以它变成了(1,3,1)。广播还会在长度为1的维度上重复这些元素。因此,当它和A相乘时,它会进一步变成(2,3,4)的形状。这和A的形状是一样的。然后,乘法就像往常一样,逐个元素进行。

撰写回答