numpy.apply_along_axis的具体功能是什么?
我在一些代码中遇到了 numpy.apply_along_axis 这个函数,但对它的说明书有点搞不懂。
这是说明书中的一个例子:
>>> def new_func(a):
... """Divide elements of a by 2."""
... return a * 0.5
>>> b = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> np.apply_along_axis(new_func, 0, b)
array([[ 0.5, 1. , 1.5],
[ 2. , 2.5, 3. ],
[ 3.5, 4. , 4.5]])
根据我对说明书的理解,我本来以为应该是:
array([[ 0.5, 1. , 1.5],
[ 4 , 5 , 6 ],
[ 7 , 8 , 9 ]])
也就是说,应该是在 [[1,2,3], [4,5,6], [7,8,9]] 这个数组的轴 0 上应用函数,而这个轴对应的是 [1,2,3]。
显然我错了。你能帮我纠正一下吗?
4 个回答
你的错误来源于对轴0的真正含义和apply_along_axis函数用途的两个误解:
误解1:
在你的数组中,轴0并不是[1,2,3],这个可以通过b[0,:]找到,它代表的是轴1的第一个切片。轴0的第一个切片可以通过索引b[:,0]找到,返回的结果是[1,4,7]。你将在轴0上进行计算的3个切片或向量是:
>>> b = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> b[:,0]
array([1, 4, 7])
>>> b[:,1]
array([2, 5, 8])
>>> b[:,2]
array([3, 6, 9])
误解2:
你想要应用的函数实际上是逐元素的。改变轴不会看到任何不同,因为这个操作不是针对轴,而是独立地作用于每个元素,并且不会影响数组的形状:
>>> b*0.5
array([[0.5, 1. , 1.5],
[2. , 2.5, 3. ],
[3.5, 4. , 4.5]])
现在,让我们深入一点,将矩阵平方,以便在每个元素之间增加更多变化,并执行np.diff,正如这个领域中另一个答案所示:
>>> b**=2
>>> b
array([[ 1, 4, 9],
[16, 25, 36],
[49, 64, 81]])
我们来取轴0的每个切片,并对它们应用np.diff:
>>> b[:,0]
array([ 1, 16, 49])
>>> np.diff(b[:,0])
array([15, 33])
>>> b[:,1]
array([ 4, 25, 64])
>>> np.diff(b[:,1])
array([21, 39])
>>> b[:,2]
array([ 9, 36, 81])
>>> np.diff(b[:,2])
array([27, 45])
为了回顾,np.diff函数的作用是:
>>> b[1::,0]-b[0:-1,0]
array([15, 33])
这个函数返回的是X[i+1]-X[i]的差值,其中i的范围是[0,len(X)-1],X是一个向量。
因此,在轴0上应用np.diff函数的结果是:
>>> np.apply_along_axis(np.diff, 0, b)
array([[15, 21, 27],
[33, 39, 45]])
最后,你可能期待的是这个答案,而不是上面那个:
array([[15, 33],
[21, 39],
[27, 45]])
然后回到原始数组,理解np.diff函数是应用于轴0的,也就是从numpy的角度看是垂直的。
这个函数是在一维数组上进行操作的,主要是沿着第一个维度(也就是轴0)。你可以通过“axis”这个参数来指定其他的维度。下面是一个使用这个方法的例子:
np.apply_along_axis(np.cumsum, 0, b)
这个函数会对每个子数组在第0维度上进行操作。所以,它是专门为一维函数设计的,并且对于每个一维输入,它会返回一个一维数组。
另一个例子是:
np.apply_along_axis(np.sum, 0, b)
它会为一维数组提供一个标量输出。其实你也可以在cumsum或者sum中直接设置轴参数来实现上面的效果,但这里的重点是,这个方法可以用于你自己写的任何一维函数。
apply_along_axis
是一个用来对输入数组的1D切片应用指定函数的工具,切片是沿着你指定的轴进行的。在你的例子中,new_func
是在数组的第一个轴的每个切片上进行操作。用一个向量值的函数来说明会更清楚,而不是用标量,像这样:
In [20]: b = np.array([[1,2,3], [4,5,6], [7,8,9]])
In [21]: np.apply_along_axis(np.diff,0,b)
Out[21]:
array([[3, 3, 3],
[3, 3, 3]])
In [22]: np.apply_along_axis(np.diff,1,b)
Out[22]:
array([[1, 1],
[1, 1],
[1, 1]])
在这里,numpy.diff
(也就是相邻数组元素的算术差)是沿着输入数组的第一个或第二个轴(维度)对每个切片进行应用的。