Numpy 多维花式索引
假设我有一个大小为 n x m x k 的 numpy 数组 A,还有一个大小为 n x m 的数组 B,B 中的值是从 1 到 k 的索引。我想通过 B 中的索引来访问 A 的每个 n x m 切片,这样我就能得到一个大小为 n x m 的数组。
编辑:看来这并不是我想要的!我可以用 take
这样来实现:
A.take(B)
结束编辑
我能不能用更复杂的索引来实现这个?我本以为 A[B]
会给我相同的结果,但结果却是一个大小为 n x m x m x k 的数组(我对此有点不理解)。
我不想用 take
的原因是我想给这个部分赋值,比如:
A[B] = 1
到目前为止,我唯一能用的解决方案是:
A.reshape(-1, k)[np.arange(n * m), B.ravel()].reshape(n, m)
但肯定还有更简单的方法吧?
1 个回答
3
假设
import numpy as np
np.random.seed(0)
n,m,k = 2,3,5
A = np.arange(n*m*k,0,-1).reshape((n,m,k))
print(A)
# [[[30 29 28 27 26]
# [25 24 23 22 21]
# [20 19 18 17 16]]
# [[15 14 13 12 11]
# [10 9 8 7 6]
# [ 5 4 3 2 1]]]
B = np.random.randint(k, size=(n,m))
print(B)
# [[4 0 3]
# [3 3 1]]
为了创建这个数组,
print(A.reshape(-1, k)[np.arange(n * m), B.ravel()])
# [26 25 17 12 7 4]
我们要用一种叫做花式索引的方法来生成一个 nxm
的数组:
i,j = np.ogrid[0:n, 0:m]
print(A[i, j, B])
# [[26 25 17]
# [12 7 4]]