Numpy多维数组切片

8 投票
2 回答
6029 浏览
提问于 2025-04-17 01:37

假设我定义了一个3x3x3的numpy数组,代码如下:

x = numpy.arange(27).reshape((3, 3, 3))

现在,我可以通过 x[:, 0, 1] 来获取每个3x3子数组中(0,1)这个位置的元素,这样会返回 array([ 1, 10, 19])。如果我有一个元组(m,n),想要获取每个子数组中(m,n)这个位置的元素,该怎么做呢?

比如,我有 t = (0, 1)。我尝试了 x[:, t],但是结果不对——它返回的是每个子数组的第0行和第1行。到目前为止,我找到的最简单的解决办法是:

x.transpose()[tuple(reversed(t))].transpose()

不过我相信一定还有更好的方法。当然,在这种情况下,我可以用 x[:, t[0], t[1]] 来实现,但这不能推广到我不知道 xt 有多少维度的情况。

2 个回答

4

HYRY的解决方案是正确的,但我总觉得numpy中的 r_c_s_ 这些索引方式看起来有点奇怪。所以这里用一个 slice 对象来做同样的事情:

x[(slice(None),) + t]

这个slice的单个参数是结束位置(也就是说,None表示全部,和 x[:] 等价于 x[None:None] 是一样的)

9

你可以先创建索引元组:

index = (numpy.s_[:],)+t 
x[index]

撰写回答