如何在Pytorch中获取2D张量中每行不同的top-k索引?
给定:
- 一个正整数张量 A,形状为 (batch_size, N),其中零是最小值。例如:
tensor([[4, 3, 1, 4, 2],
[0, 0, 2, 3, 4],
[4, 4, 3, 0, 3]])
我想要获取每一行中第 k 大值的不同索引?
- k 是一个包含
batch_size
个元素的列表,这些元素是随机选择的,值只表示两种情况:第一种是最大的(所以 k = 1),第二种是最小的,但忽略零。例如,如果这一行是 [2,3,4,0],那么最小的索引是 0(值为 2)。(选择最大的概率是 0.7,选择最小的概率是 0.3)
根据上面的例子,如果 k = [1,0,0]
(1 表示获取最大的,0 表示获取最小的)那么输出的索引将是
output = [0, 2, 2]
对应的值是 [4, 2, 3]
注意:请将这些计算向量化。
1 个回答
4
你可以这样做
import torch
x = torch.tensor([[4, 3, 1, 4, 2],
[0, 0, 2, 3, 4],
[4, 4, 3, 0, 3]]).float() # float required for later ops
k = torch.tensor([1, 0, 0]).long()
# set 0 to -1, ie [1, -1, -1]
k_sign = k + (-1 * (k==0).float())
# flip sign for rows where we want the smallest nonzero index
x_signed = x * k_sign.unsqueeze(1)
# fill zeros with -inf
x_filled = x_signed.masked_fill(x==0, float('-inf'))
# grab topk index of each row
_, output = x_filled.topk(1, dim=1)
output = output.squeeze()
output
> tensor([0, 2, 2])