在pytorch中通过两个后续的掩码操作赋值
我生成了两个不同的掩码,这些掩码是根据一些数值创建的:
import torch
values = torch.tensor([0, 0.5, 0.99, 0.87])
saved_values = values + torch.tensor([0.1, -0.4, 0, 0.1])
result = torch.zeros_like(values)
mask1 = values > 0
mask2 = ~torch.greater(saved_values[mask1], values[mask1])
现在我测试了一下,看看能不能用这些掩码来提取数据:
>>> result[mask1][mask2]
tensor([0., 0.])
看起来是可以的,所以我用广播的方式继续测试:
>>> result[:] = 5
>>> result[mask1][mask2]
tensor([5., 5.])
这也似乎没问题,所以我最后用掩码来测试这些数值:
>>> values[mask1][mask2]
tensor([0.5000, 0.9900])
结果也没问题,所以我尝试根据掩码来给这些数值赋值:
result = torch.zeros_like(values)
result[mask1][mask2] = values[mask1][mask2]
没有报错,所以我认为它是有效的,检查了两次:
>>> result[mask1][mask2]
tensor([0., 0.])
>>> result
tensor([0., 0., 0., 0.])
但似乎由于一些引用的问题,数值没有正确保存。
我该如何实现我想要的效果呢?
1 个回答
1
问题在于你进行了两次索引,目前 mask2
只有在 mask1
应用之后才会起作用。为什么不把 mask2
定义为不受 mask1 影响的呢?
>>> mask = (values > 0)*(~torch.greater(saved_values, values))
这样你就可以把两个掩码结合起来,只需要进行一次掩码操作:
>>> result[mask] = values[mask]
tensor([0.0000, 0.5000, 0.9900, 0.0000])
如果你 只能 / 只想 在用 mask1
进行掩码处理后再计算 mask2
,你仍然可以通过使用 nonzero
索引来构建结果掩码:
>>> mask = torch.zeros_like(mask1)
>>> mask[mask1.nonzero()[:,0]] = mask2
更简洁的实现方式是使用 torch.scatter_
:
>>> torch.zeros_like(mask1).scatter_(0, mask1.nonzero()[:,0], mask2)