从一维张量中提取topk值指数

2024-03-28 20:45:01 发布

您现在位置:Python中文网/ 问答频道 /正文

给定Torch(torch.Tensor)中的一维张量,其中包含可比较的值(比如浮点),我们如何提取该张量中顶部k值的索引

除了蛮力方法之外,我正在寻找Torch/lua提供的一些API调用,它们可以有效地执行此任务


Tags: 方法apitorch浮点蛮力luatensor
3条回答

只需循环使用张量并进行比较:

require 'torch'

data = torch.Tensor({1,2,3,4,505,6,7,8,9,10,11,12})
idx  = 1
max  = data[1]

for i=1,data:size()[1] do
   if data[i]>max then
      max=data[i]
      idx=i
   end
end

print(idx,max)

编辑 响应您的编辑:使用此处记录的torch.max操作:https://github.com/torch/torch7/blob/master/doc/maths.md#torchmaxresval-resind-x-dim

y, i = torch.max(x, 1) returns the largest element in each column (across rows) of x, and a Tensor i of their corresponding indices in x

您可以使用topk函数

例如:

import torch

t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])

values,indices = t.topk(2)

print(values)
print(indices)

结果是:

tensor([9.5000, 6.1000])
tensor([2, 4])

从pull请求#496起,Torch现在包括一个名为^{}的内置API。例如:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

  obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
 1
 2
 3
[torch.DoubleTensor of size 3]

  you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
 2
 4
 6
[torch.LongTensor of size 3]

  alternatively you can obtain the k largest elements as follow
  (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
 9
 8
 7
[torch.DoubleTensor of size 3]

在编写本文时,CPU实现遵循sort and narrow approach(有计划在将来改进它)。也就是说,目前正在为cutorch优化GPU实现

相关问题 更多 >