在numpy中使用掩码数组索引

6 投票
2 回答
6176 浏览
提问于 2025-04-16 04:58

我有一段代码,想要根据另一个数组里的索引来找到一个数组的内容,但有可能这些索引超出了第一个数组的范围。

input = np.arange(0, 5)
indices = np.array([0, 1, 2, 99])

我想做的是这样: 打印 input[indices] 然后得到 [0 1 2]

但这样会出现一个异常(这是可以预料的):

IndexError: index 99 out of bounds 0<=index<5

所以我想用掩码数组来隐藏那些超出范围的索引:

indices = np.ma.masked_greater_equal(indices, 5)

但结果还是:

>print input[indices]
IndexError: index 99 out of bounds 0<=index<5

尽管:

>np.max(indices)
2

所以我必须先填充这个掩码数组,这让我很烦,因为我不知道该用什么填充值来避免选择那些超出范围的索引:

打印 input[np.ma.filled(indices, 0)]

[0 1 2 0]

所以我的问题是:如何有效地使用numpy,从一个数组中安全地选择索引,而不超出输入数组的范围呢?

2 个回答

4

用带掩码的数组来索引是个非常糟糕的主意。曾经有一段时间,使用带掩码的数组进行索引会抛出异常,但那样有点太严格了……

在你的测试中,你正在过滤 indices 来找到符合条件的条目。那么,对于你的带掩码数组中缺失的条目,你该怎么处理呢?这个条件是“假”吗?还是“真”?你是否应该使用一个默认值?这完全取决于你,作为用户,来决定该怎么做。

使用 indices.filled(0) 意味着当 indices 中的某个项目被掩盖(也就是未定义)时,你想用第一个索引(0)作为默认值。这可能并不是你想要的结果。

在这里,我会简单地使用 input[indices.compressed()]:这个 compressed 方法会把你的带掩码数组压缩,只保留那些未被掩盖的条目。

不过,正如你意识到的,你可能根本就不需要带掩码的数组。

5

如果不使用掩码数组,你可以这样去掉大于或等于5的索引:

print input[indices[indices<5]]

补充说明:如果你还想去掉负的索引,可以这样写:

print input[indices[(0 <= indices) & (indices < 5)]]

撰写回答