Numpy 多维花式索引

5 投票
1 回答
3432 浏览
提问于 2025-04-17 09:36

假设我有一个大小为 n x m x k 的 numpy 数组 A,还有一个大小为 n x m 的数组 B,B 中的值是从 1 到 k 的索引。我想通过 B 中的索引来访问 A 的每个 n x m 切片,这样我就能得到一个大小为 n x m 的数组。

编辑:看来这并不是我想要的!我可以用 take 这样来实现:

A.take(B)

结束编辑

我能不能用更复杂的索引来实现这个?我本以为 A[B] 会给我相同的结果,但结果却是一个大小为 n x m x m x k 的数组(我对此有点不理解)。

我不想用 take 的原因是我想给这个部分赋值,比如:

A[B] = 1

到目前为止,我唯一能用的解决方案是:

A.reshape(-1, k)[np.arange(n * m), B.ravel()].reshape(n, m)

但肯定还有更简单的方法吧?

1 个回答

3

假设

import numpy as np
np.random.seed(0)

n,m,k = 2,3,5
A = np.arange(n*m*k,0,-1).reshape((n,m,k))
print(A)
# [[[30 29 28 27 26]
#   [25 24 23 22 21]
#   [20 19 18 17 16]]

#  [[15 14 13 12 11]
#   [10  9  8  7  6]
#   [ 5  4  3  2  1]]]

B = np.random.randint(k, size=(n,m))
print(B)
# [[4 0 3]
#  [3 3 1]]

为了创建这个数组,

print(A.reshape(-1, k)[np.arange(n * m), B.ravel()])
# [26 25 17 12  7  4]

我们要用一种叫做花式索引的方法来生成一个 nxm 的数组:

i,j = np.ogrid[0:n, 0:m]
print(A[i, j, B])
# [[26 25 17]
#  [12  7  4]]

撰写回答