使用apply_along_axis绘图

3 投票
1 回答
523 浏览
提问于 2025-04-17 22:30

我有一个三维的ndarray对象,里面包含了光谱数据(也就是空间的xy维度和一个能量维度)。我想从每个像素提取光谱,并画成折线图。目前,我是通过np.ndenumerate在我感兴趣的轴上进行操作,但速度比较慢。我希望试试np.apply_along_axis,看看能不能更快,但我总是遇到奇怪的错误。

有效的代码:

# Setup environment, and generate sample data (much smaller than real thing!)
import numpy as np
import matplotlib.pyplot as plt

ax = range(0,10) # the scale to use when plotting the axis of interest
ar = np.random.rand(4,4,10) # the 3D data volume

# Plot all lines along axis 2 (i.e. the spectrum contained in each pixel) 
# on a single line plot:

for (x,y) in np.ndenumerate(ar[:,:,1]):
    plt.plot(ax,ar[x[0],x[1],:],alpha=0.5,color='black')

我理解这基本上是一个循环,这种方法比基于数组的方法效率低,所以我想尝试用np.apply_along_axis的方法,看看能不能更快。不过这是我第一次尝试Python,我还在摸索怎么用,所以如果这个想法根本上有问题,请指正我!

我想尝试的代码:

# define a function to pass to apply_along_axis
def pa(y,x):
    if ~all(np.isnan(y)): # only do the plot if there is actually data there...
        plt.plot(x,y,alpha=0.15,color='black')
    return

# check that the function actually works...
pa(ar[1,1,:],ax) # should produce a plot - does for me :)

# try to apply to to the whole array, along the axis of interest:
np.apply_along_axis(pa,2,ar,ax) # does not work... booo!

出现的错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-109-5192831ba03c> in <module>()
     12 # pa(ar[1,1,:],ax)
     13 
---> 14 np.apply_along_axis(pa,2,ar,ax)

//anaconda/lib/python2.7/site-packages/numpy/lib/shape_base.pyc in apply_along_axis(func1d, axis, arr, *args)
    101         holdshape = outshape
    102         outshape = list(arr.shape)
--> 103         outshape[axis] = len(res)
    104         outarr = zeros(outshape, asarray(res).dtype)
    105         outarr[tuple(i.tolist())] = res

TypeError: object of type 'NoneType' has no len()

有没有人知道这里出了什么问题?或者有什么建议可以让我做得更好?非常感谢!

1 个回答

2

apply_along_axis 是用来根据你写的函数生成 一个新数组 的。

你现在返回的是 None(因为你没有返回任何东西),所以出现了错误。Numpy 会检查你返回的结果的长度,看看它是否适合新数组。

因为你没有用结果来构建新数组,所以其实没必要使用 apply_along_axis。这样做也不会更快。

不过,你现在的 ndenumerate 语句其实和下面这个是一样的:

import numpy as np
import matplotlib.pyplot as plt

ar = np.random.rand(4,4,10) # the 3D data volume
plt.plot(ar.reshape(-1, 10).T, alpha=0.5, color='black')

一般来说,你可能想要做的是:

for pixel in ar.reshape(-1, ar.shape[-1]):
    plt.plot(x_values, pixel, ...)

这样你就可以轻松地遍历你超光谱数组中每个像素的光谱。


你这里的瓶颈可能不是你怎么使用数组。用 matplotlib 分别绘制每一条线,参数都一样,这样做会有点低效。

虽然构建会稍微花点时间,但使用 LineCollection 会渲染得快得多。(简单来说,使用 LineCollection 是告诉 matplotlib 不用去检查每条线的属性,而是把它们都交给底层的渲染器以相同的方式绘制。这样你就避免了多个单独的 draw 调用,而是用一个 draw 来处理一个大对象。)

不过,缺点是代码可能会稍微难读一些。

我稍后会加个例子。

撰写回答