在NumPy中使用多层布尔索引掩码

14 投票
5 回答
8790 浏览
提问于 2025-04-17 00:16

我有一段代码,首先是用一个逻辑索引掩码来选择NumPy数组中的元素:

import numpy as np

grid = np.random.rand(4,4) 
mask = grid > 0.5

接下来,我想用第二个布尔掩码来从中挑选出符合某些条件的对象:

masklength = len(grid[mask])
prob = 0.5
# generates an random array of bools
second_mask = np.random.rand(masklength) < prob 

# this fails to act on original object
grid[mask][second_mask] = 100

这个问题和StackOverflow上提到的那个问题不太一样:

Numpy数组,如何选择满足多个条件的索引? - 因为我在使用随机数生成,我不想生成一个完整的掩码,只想针对第一个掩码选中的元素进行处理。

5 个回答

2

我在尝试做类似的事情时,发现了这个旧的讨论串。这里有很多有趣的回答,但我觉得我想出了一个更简单、更直观的方法:就是把第一个掩码(mask)用在自己身上,来根据需要把True值变成True或False。

方法很简单,只需要一行代码,然后就可以根据需要使用这个掩码了:

mask[mask] = second_mask
11

使用扁平索引可以避免很多麻烦:

grid.flat[np.flatnonzero(mask)[second_mask]] = 100

我们来分解一下:

ind = np.flatnonzero(mask)

这个过程生成了一个扁平的索引数组,只有在 mask 为真的地方才会有值,然后再通过应用 second_mask 进一步筛选:

ind = ind[second_mask] 

我们可以继续深入讲:

ind = ind[third_mask]

最后

grid.flat[ind] = 100

indgrid 的扁平版本进行索引,并赋值为 100grid.ravel()[ind] = 100 也可以实现这个效果,因为 ravel() 会返回原始数组的扁平视图。

7

我觉得下面的代码可以满足你的需求:

grid[[a[second_mask] for a in np.where(mask)]] = 100

它的工作原理如下:

  • np.where(mask) 会把布尔掩码转换成那些 mask 为真的位置索引;
  • [a[second_mask] for a in ...] 会根据索引,只选择那些 second_mask 为真的位置。

你最开始的版本不工作的原因是 grid[mask] 使用了复杂的索引。这会创建数据的一个副本,结果是 ...[second_mask] = 100 修改的是这个副本,而不是原始数组。

撰写回答