如何在Pytorch中获取2D张量中每行不同的top-k索引?

0 投票
1 回答
50 浏览
提问于 2025-04-14 17:45

给定:

  • 一个正整数张量 A,形状为 (batch_size, N),其中零是最小值。例如:
tensor([[4, 3, 1, 4, 2],
        [0, 0, 2, 3, 4],
        [4, 4, 3, 0, 3]])

我想要获取每一行中第 k 大值的不同索引?

  • k 是一个包含 batch_size 个元素的列表,这些元素是随机选择的,值只表示两种情况:第一种是最大的(所以 k = 1),第二种是最小的,但忽略零。例如,如果这一行是 [2,3,4,0],那么最小的索引是 0(值为 2)。(选择最大的概率是 0.7,选择最小的概率是 0.3)

根据上面的例子,如果 k = [1,0,0](1 表示获取最大的,0 表示获取最小的)那么输出的索引将是

output = [0, 2, 2] 对应的值是 [4, 2, 3]

注意:请将这些计算向量化。

1 个回答

4

你可以这样做

import torch

x = torch.tensor([[4, 3, 1, 4, 2],
                  [0, 0, 2, 3, 4],
                  [4, 4, 3, 0, 3]]).float() # float required for later ops

k = torch.tensor([1, 0, 0]).long()

# set 0 to -1, ie [1, -1, -1]
k_sign = k + (-1 * (k==0).float())

# flip sign for rows where we want the smallest nonzero index
x_signed = x * k_sign.unsqueeze(1)

# fill zeros with -inf
x_filled = x_signed.masked_fill(x==0, float('-inf'))

# grab topk index of each row
_, output = x_filled.topk(1, dim=1)

output = output.squeeze()

output
> tensor([0, 2, 2])

撰写回答