在numpy中,如何最快地将三维数组的第二维与一维数组相乘?
你有一个形状为 (a,b,c) 的数组,想要把第二个维度(也就是 b 这一部分)乘以一个形状为 (b) 的数组。
用一个 for 循环是可以做到的,但有没有更好的方法呢?
比如:
A = np.array(shape=(a,b,c))
B = np.array(shape=(b))
for i in B.shape[0]:
A[:,i,:]=A[:,i,:]*B[i]
1 个回答
6
使用广播的概念:
A*B[:,np.newaxis]
举个例子:
In [47]: A=np.arange(24).reshape(2,3,4)
In [48]: B=np.arange(3)
In [49]: A*B[:,np.newaxis]
Out[49]:
array([[[ 0, 0, 0, 0],
[ 4, 5, 6, 7],
[16, 18, 20, 22]],
[[ 0, 0, 0, 0],
[16, 17, 18, 19],
[40, 42, 44, 46]]])
B[:,np.newaxis]
的形状是(3,1)。广播的意思是在左边添加新的维度,所以它变成了(1,3,1)。广播还会在长度为1的维度上重复这些元素。因此,当它和A
相乘时,它会进一步变成(2,3,4)的形状。这和A
的形状是一样的。然后,乘法就像往常一样,逐个元素进行。