Pytorch张量获取具有特定值的元素的索引?

2024-03-29 01:20:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个张量,张量a和张量b

我想得到张量b中所有值的索引

比如说

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])

我想要张量a中的1, 2, 4的索引。我可以通过以下代码来实现这一点

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
mask = torch.zeros(a.shape).type(torch.bool)
print(mask)
for e in b:
    mask = mask + (a == e)
    print(mask)

没有for我怎么做


Tags: 代码infortypezerosmasktorchbool
2条回答

如果您不想使用for循环,可以使用列表理解:

mask = [a[index] for index in b]

如果你甚至不想使用“for”这个词,你可以将张量转换成numpy,并使用numpy索引

mask = torch.tensor(a.numpy()[b.numpy()])

更新

可能误解了你的问题。在这种情况下,我想说实现这一点的最佳方法是通过列表理解。(切片可能无法实现这一点

mask = [index for index,value in enumerate(a) if value in b.tolist()] 

它迭代a中的每个元素,获取它们的索引和值,如果值在b中,则获取索引

这就是你想要的吗

np.in1d(a.numpy(), b.numpy())

将导致:

array([ True,  True,  True, False,  True,  True,  True, False])

相关问题 更多 >