Numpy 广播数组
我在学习Python中的广播机制时,遇到了一个形状不匹配的错误。我知道这意味着我使用的数组在维度上不相符。我的代码基本上是想对以下维度的数组进行这些操作:
(256,256,3)*(256,256)+(256,256)
我知道问题出在乘法上。我在想有没有办法解决这个问题?我可以给(256,256)的数组加一个额外的维度吗?
1 个回答
3
假设我们有
A.shape = (256,256,3)
B.shape = (256,256)
C.shape = (256,256)
NumPy的广播功能默认是在左边添加轴,所以这会导致 B
和 C
被扩展成
B.shape = (256,256,256)
C.shape = (256,256,256)
显然这样是不行的,也不是你想要的,因为它们的形状和 A
不匹配。
所以当你想在右边添加一个轴时,可以使用 B[..., np.newaxis]
和 C[..., np.newaxis]
:
A*B[..., np.newaxis] + C[..., np.newaxis]
B[..., np.newaxis]
的形状是 (256,256,1)
,在和 A
相乘时会扩展成 (256,256,3)
,C[..., np.newaxis]
也是一样的。
B[..., np.newaxis]
也可以写成 B[..., None]
-- 因为 np.newaxis
实际上就是 None
。这样写稍微简短一些,但可能不太容易理解。