在numpy中使用掩码数组索引
我有一段代码,想要根据另一个数组里的索引来找到一个数组的内容,但有可能这些索引超出了第一个数组的范围。
input = np.arange(0, 5)
indices = np.array([0, 1, 2, 99])
我想做的是这样: 打印 input[indices] 然后得到 [0 1 2]
但这样会出现一个异常(这是可以预料的):
IndexError: index 99 out of bounds 0<=index<5
所以我想用掩码数组来隐藏那些超出范围的索引:
indices = np.ma.masked_greater_equal(indices, 5)
但结果还是:
>print input[indices]
IndexError: index 99 out of bounds 0<=index<5
尽管:
>np.max(indices)
2
所以我必须先填充这个掩码数组,这让我很烦,因为我不知道该用什么填充值来避免选择那些超出范围的索引:
打印 input[np.ma.filled(indices, 0)]
[0 1 2 0]
所以我的问题是:如何有效地使用numpy,从一个数组中安全地选择索引,而不超出输入数组的范围呢?
2 个回答
4
用带掩码的数组来索引是个非常糟糕的主意。曾经有一段时间,使用带掩码的数组进行索引会抛出异常,但那样有点太严格了……
在你的测试中,你正在过滤 indices
来找到符合条件的条目。那么,对于你的带掩码数组中缺失的条目,你该怎么处理呢?这个条件是“假”吗?还是“真”?你是否应该使用一个默认值?这完全取决于你,作为用户,来决定该怎么做。
使用 indices.filled(0)
意味着当 indices
中的某个项目被掩盖(也就是未定义)时,你想用第一个索引(0)作为默认值。这可能并不是你想要的结果。
在这里,我会简单地使用 input[indices.compressed()]
:这个 compressed
方法会把你的带掩码数组压缩,只保留那些未被掩盖的条目。
不过,正如你意识到的,你可能根本就不需要带掩码的数组。
5
如果不使用掩码数组,你可以这样去掉大于或等于5的索引:
print input[indices[indices<5]]
补充说明:如果你还想去掉负的索引,可以这样写:
print input[indices[(0 <= indices) & (indices < 5)]]