对于给定的2D张量,我想检索值为1
的所有索引。我希望能够简单地使用torch.nonzero(a == 1).squeeze()
,它将返回tensor([1, 3, 2])
。但是,相反,torch.nonzero(a == 1)
返回一个2D张量(没关系),每行有两个值(这不是我预期的)。然后应该使用返回的索引来索引3D张量的第二维度(索引1),再次返回2D张量。在
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])
print(b.gather(1, idxs))
显然,这不起作用,导致出现错误:
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453
似乎idxs
不是我所期望的,也不能像我想的那样使用它。idxs
是
但是通读documentation我不明白为什么我也会得到结果张量中的行索引。现在,我知道我可以通过切片idxs[:, 1]
得到正确的idx,但是,我仍然不能使用这些值作为3D张量的索引,因为与之前一样的错误被提出。是否可以使用索引的一维张量来选择给定维度上的项目?在
假设
b
的三个维度是batch_size x sequence_length x features
(bxs x feats),则可以获得如下预期结果。在您可以简单地将它们切片并作为索引传递,如:
或者,一种更简单的方法是只使用^{} ,然后直接索引到张量{},如下所示:
^{pr2}$关于上述使用
torch.where()
的方法的更多解释:它基于advanced indexing的概念工作。也就是说,当我们使用序列对象的元组(如张量元组、列表元组、元组等)索引到张量中时对于基本切片,我们需要一个整数索引元组:
要使用高级索引实现相同的功能,我们需要序列对象的元组:
返回张量的维数总是比输入张量的维数小一个维数。在
相关问题 更多 >
编程相关推荐