从另一个2D数组中提取元素和索引的2D数组
我有一个二维的numpy数组,叫做data,它的形状是(n,8),还有另一个数组ind,它的形状是(n,4)。这两个数组的长度是一样的,ind数组里面包含了一些索引,比如说[4,3,0,6]。我想要创建一个新的数组,它的形状是(n,4),这个新数组里的元素是从data中根据ind里的索引提取出来的。我的实际数组很长(shape[0]),所以用循环来处理会很慢。有没有比循环更好的方法呢?
import numpy as np
# Example data
data = np.array([[ 0.44180102, -0.05941365, 2.1482739 , -0.56875081, -1.45400572,
-1.44391254, -0.33710766, -0.44214518],
[ 0.79506417, -2.46156966, -0.09929341, -1.07347179, 1.03986533,
-0.45745476, 0.58853107, -1.08565425],
[ 1.40348682, -1.43396403, 0.8267174 , -1.54812358, -1.05854445,
0.15789466, -0.0666025 , 0.29058816]])
ind = np.array([[3, 4, 1, 5],
[4, 7, 0, 1],
[5, 1, 3, 6]])
# This is the part I want to vectorize:
out = np.zeros(ind.shape)
for i in range(ind.shape[0]):
out[i,:] = data[i,ind[i,:]]
# This should be good
assert np.all(out == np.array([[-0.56875081, -1.45400572, -0.05941365, -1.44391254],
[ 1.03986533, -1.08565425, 0.79506417, -2.46156966],
[ 0.15789466, -1.43396403, -1.54812358, -0.0666025 ]]))
2 个回答
3
你想要的结果大概是这样的:
import numpy as np
data = np.array([[ 0.4, -0.1, 2.1, -0.6, -1.5, -1.4, -0.3, -0.4],
[ 0.8, -2.5, -0.1, -1.1, 1. , -0.5, 0.6, -1.1],
[ 1.4, -1.4, 0.8, -1.5, -1.1, 0.2, -0.1, 0.3]])
expected = np.array([[-0.6, -1.5, -0.1, -1.4],
[ 1. , -1.1, 0.8, -2.5],
[ 0.2, -1.4, -1.5, -0.1]])
indI = np.array([[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]])
indJ = np.array([[3, 4, 1, 5],
[4, 7, 0, 1],
[5, 1, 3, 6]])
out = data[indI, indJ]
assert np.all(out == expected)
注意到 indI
和 indJ
的形状是一样的,并且
out[i, j] == data[indI[i, j], indJ[i, j]]
对于所有的 i
和 j
来说都是如此。
你可能会发现 indI
的内容非常重复。由于 numpy 的 广播 特性,你可以简单地将 indI
改成:
indI = np.array([[0],
[1],
[2]])
你可以用几种不同的方法来构建这种类型的 indI
数组,这里是我最喜欢的一种:
a, b = indJ.shape
indI, _ = np.ogrid[:a, :0]
out = data[indI, indJ]
5
如果我们直接从展开的 data
数组中提取数据,这个操作会变得很简单:
out = data.ravel()[ind.ravel() + np.repeat(range(0, 8*ind.shape[0], 8), ind.shape[1])].reshape(ind.shape)
解释
把这个过程分成三个步骤,可能会更容易理解:
indices = ind.ravel() + np.repeat(range(0, 8*ind.shape[0], 8), ind.shape[1])
out = data.ravel()[indices]
out = out.reshape(ind.shape)
ind
里包含了我们想要从 data
中提取的元素的信息。不过,这些信息是以二维的方式表示的。第一行代码将这些二维的索引转换成一维展开的 data
的索引。第二行代码则从展开的 data
数组中选择出这些元素。第三行代码把提取出来的元素恢复成二维的形状,存放在 out
中。
用 ind
表示的二维索引被转换成了一维的 indices
,这些就是我们需要的索引。