获取张量a中存在于张量b中的元素的索引

2024-03-28 17:34:34 发布

您现在位置:Python中文网/ 问答频道 /正文

例如,我想得到张量a中值为0和2的元素的索引。这些值(0和2)存储在张量b中。我设计了一种pythonic方法来实现这一点(如下所示),但我认为列表理解并没有优化到可以在GPU上运行,或者可能有一种我不知道的更pythorchy的方法来实现

import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()

>>>> tensor([[0],
             [2],
             [5],
             [6]])

还有其他建议吗?或者这是一种可以接受的方式吗


Tags: 方法inimport元素列表forgpu方式
1条回答
网友
1楼 · 发布于 2024-03-28 17:34:34

这里有一个更有效的方法(正如jodag在评论中发布的链接所建议的那样…):

(a[..., None] == b).any(-1).nonzero()

相关问题 更多 >