在pytorch中通过两个后续的掩码操作赋值

0 投票
1 回答
36 浏览
提问于 2025-04-14 18:35

我生成了两个不同的掩码,这些掩码是根据一些数值创建的:

import torch

values = torch.tensor([0, 0.5, 0.99, 0.87])
saved_values = values + torch.tensor([0.1, -0.4, 0, 0.1])
result = torch.zeros_like(values)

mask1 = values > 0
mask2 = ~torch.greater(saved_values[mask1], values[mask1])

现在我测试了一下,看看能不能用这些掩码来提取数据:

>>> result[mask1][mask2]
tensor([0., 0.])

看起来是可以的,所以我用广播的方式继续测试:

>>> result[:] = 5
>>> result[mask1][mask2]
tensor([5., 5.])

这也似乎没问题,所以我最后用掩码来测试这些数值:

>>> values[mask1][mask2]
tensor([0.5000, 0.9900])

结果也没问题,所以我尝试根据掩码来给这些数值赋值:

result = torch.zeros_like(values)
result[mask1][mask2] = values[mask1][mask2]

没有报错,所以我认为它是有效的,检查了两次:

>>> result[mask1][mask2]
tensor([0., 0.])
>>> result
tensor([0., 0., 0., 0.])

但似乎由于一些引用的问题,数值没有正确保存。

我该如何实现我想要的效果呢?

1 个回答

1

问题在于你进行了两次索引,目前 mask2 只有在 mask1 应用之后才会起作用。为什么不把 mask2 定义为不受 mask1 影响的呢?

>>> mask = (values > 0)*(~torch.greater(saved_values, values))

这样你就可以把两个掩码结合起来,只需要进行一次掩码操作:

>>> result[mask] = values[mask]
tensor([0.0000, 0.5000, 0.9900, 0.0000])

如果你 只能 / 只想 在用 mask1 进行掩码处理后再计算 mask2,你仍然可以通过使用 nonzero 索引来构建结果掩码:

>>> mask = torch.zeros_like(mask1)
>>> mask[mask1.nonzero()[:,0]] = mask2

更简洁的实现方式是使用 torch.scatter_

>>> torch.zeros_like(mask1).scatter_(0, mask1.nonzero()[:,0], mask2)

撰写回答