如何切片一个numpy数组?

4 投票
2 回答
3366 浏览
提问于 2025-04-17 12:00

m 是一个形状为 (12, 21, 21) 的 ndarray,现在我想从中提取一个稀疏的切片,形成一个新的二维数组,具体是这样的:

sliceid = 0
indx    = np.array([0, 2, 4, 6, 8, 10])

这样 sparse_slice 的直观结果应该是:

sparse_slice = m[sliceid, indx, indx]

但是显然上面的操作并没有成功,目前我使用的方法是:

sparse_slice = m[sliceid,indx,:][:, indx]

为什么第一个“直观”的方法不奏效?有没有比我现在的解决方案更简洁的方法?我之前的 ndarray 切片尝试完全是凭直觉,或许我该去看看一些正规的手册了……

2 个回答

2

如果我没记错的话,对于输入 m = np.array(range(5292)).reshape(12,21,21),你期望的输出是 sparse_slice = m[sliceid,indx,:][:, indx] 的结果是

array([[  0,   2,   4,   6,   8,  10],
       [ 42,  44,  46,  48,  50,  52],
       [ 84,  86,  88,  90,  92,  94],
       [126, 128, 130, 132, 134, 136],
       [168, 170, 172, 174, 176, 178],
       [210, 212, 214, 216, 218, 220]])

在这种情况下,你可以使用切片的 step 部分来实现这个效果:

m[0, :12:2, :12:2]

5

更简洁的写法是用 new = m[0, :12:2, :12:2]。这就是numpy文档中提到的“基本索引”,意思是你用整数或切片对象(比如0:12:2)来切片。当你使用基本索引时,numpy会返回原始数组的一个视图。例如:

In [3]: a = np.zeros((2, 3, 4))

In [4]: b = a[0, 1, ::2]

In [5]: b
Out[5]: array([ 0.,  0.])

In [6]: b[:] = 7

In [7]: a
Out[7]: 
array([[[ 0.,  0.,  0.,  0.],
        [ 7.,  0.,  7.,  0.],
        [ 0.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.]]])

在你“直观”的方法中,你是用一个数组来索引另一个数组。当你用一个数组去索引numpy数组时,这两个数组需要大小相同(或者它们需要能够相互广播,稍后会详细讲)。在文档中,这被称为花式索引或高级索引。例如:

In [10]: a = np.arange(9).reshape(3,3)

In [11]: a
Out[11]: 
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])

In [12]: index = np.array([0,1,2])

In [13]: b = a[index, index]

In [14]: b
Out[14]: array([0, 4, 8])

你会看到我得到了a[0,0]、a[1,1]和a[2,2],而不是a[0,0]、a[0,1]……如果你想要索引的“外积”,可以这样做。

In [22]: index1 = np.array([[0,0],[1,1]])

In [23]: index2 = np.array([[0,1],[0,1]])

In [24]: b = a[index1, index2]

In [25]: b
Out[25]: 
array([[0, 1],
       [3, 4]])

上面的操作有个简写方式,像这样:

In [28]: index = np.array([0,1])

In [29]: index1, index2 = np.ix_(index, index)

In [31]: index1
Out[31]: 
array([[0],
       [1]])

In [32]: index2
Out[32]: array([[0, 1]])

In [33]: a[index1, index2]
Out[33]: 
array([[0, 1],
       [3, 4]])

In [34]: a[np.ix_(index, index)]
Out[34]: 
array([[0, 1],
       [3, 4]])

你会注意到 index1(2, 1),而 index2(1, 2),而不是 (2, 2)。这是因为这两个数组会相互广播,你可以在这里了解更多关于广播的内容。记住,当你使用花式索引时,你得到的是原始数据的一个副本,而不是视图。有时候这样更好(如果你想保持原始数据不变),但有时候这会占用更多内存。关于索引的更多信息可以在这里找到。

撰写回答