Numpy多维数组切片
假设我定义了一个3x3x3的numpy数组,代码如下:
x = numpy.arange(27).reshape((3, 3, 3))
现在,我可以通过 x[:, 0, 1]
来获取每个3x3子数组中(0,1)这个位置的元素,这样会返回 array([ 1, 10, 19])
。如果我有一个元组(m,n),想要获取每个子数组中(m,n)这个位置的元素,该怎么做呢?
比如,我有 t = (0, 1)
。我尝试了 x[:, t]
,但是结果不对——它返回的是每个子数组的第0行和第1行。到目前为止,我找到的最简单的解决办法是:
x.transpose()[tuple(reversed(t))].transpose()
不过我相信一定还有更好的方法。当然,在这种情况下,我可以用 x[:, t[0], t[1]]
来实现,但这不能推广到我不知道 x
和 t
有多少维度的情况。
2 个回答
4
HYRY的解决方案是正确的,但我总觉得numpy中的 r_
、c_
和 s_
这些索引方式看起来有点奇怪。所以这里用一个 slice
对象来做同样的事情:
x[(slice(None),) + t]
这个slice的单个参数是结束位置(也就是说,None
表示全部,和 x[:]
等价于 x[None:None]
是一样的)
9
你可以先创建索引元组:
index = (numpy.s_[:],)+t
x[index]