NumPy中的轴参数是如何工作的?
有人能解释一下NumPy中的axis
参数到底是干什么的吗?
我真是搞不懂。
我正在尝试使用这个函数myArray.sum(axis=num)
一开始我以为,如果数组本身是三维的,axis=0
会返回三个元素,这三个元素是同一位置上所有嵌套项的总和。如果每个维度都有五个元素,我预计axis=1
会返回五个结果,依此类推。
但是事实并不是这样,文档也没有很好地帮助我 (他们使用的是一个3x3x3的数组,所以很难看出发生了什么)
这是我做的:
>>> e
array([[[1, 0],
[0, 0]],
[[1, 1],
[1, 0]],
[[1, 0],
[0, 1]]])
>>> e.sum(axis = 0)
array([[3, 1],
[1, 1]])
>>> e.sum(axis=1)
array([[1, 0],
[2, 1],
[1, 1]])
>>> e.sum(axis=2)
array([[1, 0],
[2, 1],
[1, 1]])
>>>
显然,结果并不直观。
6 个回答
有些回答太具体,或者没有解决主要的困惑。这个回答试图提供一个更通俗易懂的解释,并给出一个简单的例子。
主要的困惑来源于像“计算均值的轴”这样的表达,这其实是指 numpy.mean
函数中的 axis
参数的说明。那么“沿着哪个轴”到底是什么意思呢?这里的“沿着”其实就是指,如果 axis
是 0,你就会对每一行的数值进行求和(然后再除以行数,得到均值);如果 axis
是 1,那就对每一列的数值进行求和。简单来说,当 axis
是 0(或 1)时,行可以是单个数字、向量,甚至是其他多维数组。
In [1]: import numpy as np
In [2]: a=np.array([[1, 2], [3, 4]])
In [3]: a
Out[3]:
array([[1, 2],
[3, 4]])
In [4]: np.mean(a, axis=0)
Out[4]: array([2., 3.])
In [5]: np.mean(a, axis=1)
Out[5]: array([1.5, 3.5])
所以,在上面的例子中,np.mean(a, axis=0)
返回 array([2., 3.])
,因为 (1 + 3)/2 = 2
和 (2 + 4)/2 = 3
。它返回一个包含两个数字的数组,因为它是计算每一列的行均值(而且这里有两列)。
这里有一些关于可视化的好答案,但从分析的角度思考可能会更有帮助。
你可以使用numpy创建任意维度的数组。比如,这里有一个5维的数组:
>>> a = np.random.rand(2, 3, 4, 5, 6)
>>> a.shape
(2, 3, 4, 5, 6)
你可以通过指定索引来访问这个数组中的任何元素。例如,这里是这个数组的第一个元素:
>>> a[0, 0, 0, 0, 0]
0.0038908603263844155
如果你去掉其中一个维度,你就能得到那个维度中的元素数量:
>>> a[0, 0, :, 0, 0]
array([0.00389086, 0.27394775, 0.26565889, 0.62125279])
当你使用像sum
这样的函数,并且指定axis
参数时,那个维度就会被消除,生成一个维度比原来少的新数组。对于新数组中的每个单元,操作符会获取一组元素,并应用减少函数来得到一个标量值。
>>> np.sum(a, axis=2).shape
(2, 3, 5, 6)
现在你可以检查这个数组的第一个元素,它是上面元素的总和:
>>> np.sum(a, axis=2)[0, 0, 0, 0]
1.1647502999560164
>>> a[0, 0, :, 0, 0].sum()
1.1647502999560164
axis=None
有特别的含义,它会将数组展平,并对所有数字应用函数。
现在你可以考虑更复杂的情况,其中轴不仅仅是一个数字,而是一个元组:
>>> np.sum(a, axis=(2,3)).shape
(2, 3, 6)
注意,我们使用相同的技巧来弄清楚这个减少是如何完成的:
>>> np.sum(a, axis=(2,3))[0,0,0]
7.889432081931909
>>> a[0, 0, :, :, 0].sum()
7.88943208193191
你也可以用同样的思路来为数组添加维度,而不是减少维度:
>>> x = np.random.rand(3, 4)
>>> y = np.random.rand(3, 4)
# New dimension is created on specified axis
>>> np.stack([x, y], axis=2).shape
(3, 4, 2)
>>> np.stack([x, y], axis=0).shape
(2, 3, 4)
# To retrieve item i in stack set i in that axis
希望这能让你对这个重要参数有一个全面而通用的理解。
为了更直观地理解 axis
,可以参考下面的图片(来源:康奈尔大学物理系)
上面图中的(布尔)数组的 形状 是 shape=(8, 3)
。ndarray.shape 会返回一个 元组,这个元组里的每个值对应着特定维度的长度。在我们的例子中,8
是 轴 0 的长度,而 3
是 轴 1 的长度。
很明显,
e.shape == (3, 2, 2)
在某个轴上求和是一种简化操作,所以指定的那个轴会消失。因此,
e.sum(axis=0).shape == (2, 2)
e.sum(axis=1).shape == (3, 2)
e.sum(axis=2).shape == (3, 2)
直观来说,我们是在“压缩”数组沿着选定的轴,把被压缩在一起的数字加起来。