在NumPy中使用多层布尔索引掩码
我有一段代码,首先是用一个逻辑索引掩码来选择NumPy数组中的元素:
import numpy as np
grid = np.random.rand(4,4)
mask = grid > 0.5
接下来,我想用第二个布尔掩码来从中挑选出符合某些条件的对象:
masklength = len(grid[mask])
prob = 0.5
# generates an random array of bools
second_mask = np.random.rand(masklength) < prob
# this fails to act on original object
grid[mask][second_mask] = 100
这个问题和StackOverflow上提到的那个问题不太一样:
Numpy数组,如何选择满足多个条件的索引? - 因为我在使用随机数生成,我不想生成一个完整的掩码,只想针对第一个掩码选中的元素进行处理。5 个回答
2
我在尝试做类似的事情时,发现了这个旧的讨论串。这里有很多有趣的回答,但我觉得我想出了一个更简单、更直观的方法:就是把第一个掩码(mask)用在自己身上,来根据需要把True值变成True或False。
方法很简单,只需要一行代码,然后就可以根据需要使用这个掩码了:
mask[mask] = second_mask
11
使用扁平索引可以避免很多麻烦:
grid.flat[np.flatnonzero(mask)[second_mask]] = 100
我们来分解一下:
ind = np.flatnonzero(mask)
这个过程生成了一个扁平的索引数组,只有在 mask
为真的地方才会有值,然后再通过应用 second_mask
进一步筛选:
ind = ind[second_mask]
我们可以继续深入讲:
ind = ind[third_mask]
最后
grid.flat[ind] = 100
用 ind
对 grid
的扁平版本进行索引,并赋值为 100
。 grid.ravel()[ind] = 100
也可以实现这个效果,因为 ravel()
会返回原始数组的扁平视图。
7
我觉得下面的代码可以满足你的需求:
grid[[a[second_mask] for a in np.where(mask)]] = 100
它的工作原理如下:
np.where(mask)
会把布尔掩码转换成那些mask
为真的位置索引;[a[second_mask] for a in ...]
会根据索引,只选择那些second_mask
为真的位置。
你最开始的版本不工作的原因是 grid[mask]
使用了复杂的索引。这会创建数据的一个副本,结果是 ...[second_mask] = 100
修改的是这个副本,而不是原始数组。