使用apply_along_axis绘图
我有一个三维的ndarray对象,里面包含了光谱数据(也就是空间的xy维度和一个能量维度)。我想从每个像素提取光谱,并画成折线图。目前,我是通过np.ndenumerate在我感兴趣的轴上进行操作,但速度比较慢。我希望试试np.apply_along_axis
,看看能不能更快,但我总是遇到奇怪的错误。
有效的代码:
# Setup environment, and generate sample data (much smaller than real thing!)
import numpy as np
import matplotlib.pyplot as plt
ax = range(0,10) # the scale to use when plotting the axis of interest
ar = np.random.rand(4,4,10) # the 3D data volume
# Plot all lines along axis 2 (i.e. the spectrum contained in each pixel)
# on a single line plot:
for (x,y) in np.ndenumerate(ar[:,:,1]):
plt.plot(ax,ar[x[0],x[1],:],alpha=0.5,color='black')
我理解这基本上是一个循环,这种方法比基于数组的方法效率低,所以我想尝试用np.apply_along_axis
的方法,看看能不能更快。不过这是我第一次尝试Python,我还在摸索怎么用,所以如果这个想法根本上有问题,请指正我!
我想尝试的代码:
# define a function to pass to apply_along_axis
def pa(y,x):
if ~all(np.isnan(y)): # only do the plot if there is actually data there...
plt.plot(x,y,alpha=0.15,color='black')
return
# check that the function actually works...
pa(ar[1,1,:],ax) # should produce a plot - does for me :)
# try to apply to to the whole array, along the axis of interest:
np.apply_along_axis(pa,2,ar,ax) # does not work... booo!
出现的错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-109-5192831ba03c> in <module>()
12 # pa(ar[1,1,:],ax)
13
---> 14 np.apply_along_axis(pa,2,ar,ax)
//anaconda/lib/python2.7/site-packages/numpy/lib/shape_base.pyc in apply_along_axis(func1d, axis, arr, *args)
101 holdshape = outshape
102 outshape = list(arr.shape)
--> 103 outshape[axis] = len(res)
104 outarr = zeros(outshape, asarray(res).dtype)
105 outarr[tuple(i.tolist())] = res
TypeError: object of type 'NoneType' has no len()
有没有人知道这里出了什么问题?或者有什么建议可以让我做得更好?非常感谢!
1 个回答
apply_along_axis
是用来根据你写的函数生成 一个新数组 的。
你现在返回的是 None
(因为你没有返回任何东西),所以出现了错误。Numpy 会检查你返回的结果的长度,看看它是否适合新数组。
因为你没有用结果来构建新数组,所以其实没必要使用 apply_along_axis
。这样做也不会更快。
不过,你现在的 ndenumerate
语句其实和下面这个是一样的:
import numpy as np
import matplotlib.pyplot as plt
ar = np.random.rand(4,4,10) # the 3D data volume
plt.plot(ar.reshape(-1, 10).T, alpha=0.5, color='black')
一般来说,你可能想要做的是:
for pixel in ar.reshape(-1, ar.shape[-1]):
plt.plot(x_values, pixel, ...)
这样你就可以轻松地遍历你超光谱数组中每个像素的光谱。
你这里的瓶颈可能不是你怎么使用数组。用 matplotlib
分别绘制每一条线,参数都一样,这样做会有点低效。
虽然构建会稍微花点时间,但使用 LineCollection
会渲染得快得多。(简单来说,使用 LineCollection
是告诉 matplotlib 不用去检查每条线的属性,而是把它们都交给底层的渲染器以相同的方式绘制。这样你就避免了多个单独的 draw
调用,而是用一个 draw
来处理一个大对象。)
不过,缺点是代码可能会稍微难读一些。
我稍后会加个例子。