如何将pytorch张量中一行中的最大n个元素设置为1,并将其设置为零?

2024-04-24 21:35:49 发布

您现在位置:Python中文网/ 问答频道 /正文

例如,如果我有张量

0.8, 0.1, 0.9, 0.2
0.7, 0.1, 0.4, 0.6

n=2,我想得到

1, 0, 1, 0
1, 0, 0, 1

或者也许做得很简单更好,但问题是一样的


3条回答

为了提高性能效率,我们可以使用^{}-

def partition_assign(a, n):
    idx = np.argpartition(a,-n,axis=1)[:,-n:]
    out = np.zeros(a.shape, dtype=int)
    np.put_along_axis(out,idx,1,axis=1)
    return out

或者我们可以在argpartition步骤中使用np.argpartition(-a,n,axis=1)[:,:n]

样本运行-

In [56]: a
Out[56]: 
array([[0.8, 0.1, 0.9, 0.2],
       [0.7, 0.1, 0.4, 0.6]])

In [57]: partition_assign(a, n=2)
Out[57]: 
array([[1, 0, 1, 0],
       [1, 0, 0, 1]])

In [58]: partition_assign(a, n=3)
Out[58]: 
array([[1, 0, 1, 1],
       [1, 0, 1, 1]])

您可以使用^{}获取索引,然后将其设置为1以获得新的张量

t = torch.tensor([[0.8, 0.1, 0.9, 0.2],[0.7, 0.1, 0.4, 0.6]])
tb = torch.zeros(t.shape)                      # create new tensor for 1,0

# set `1` to topk indices
tb[(torch.arange(len(t)).unsqueeze(1), torch.topk(t,2).indices)] = 1
tb
tensor([[1., 0., 1., 0.],
        [1., 0., 0., 1.]])

n=3

tb = torch.zeros(t.shape)
tb[(torch.arange(len(t)).unsqueeze(1), torch.topk(t,3).indices)] = 1
tb
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.]])

你可以使用列表理解。 简单快速

首先选择每个数组的阈值(基于n

def tensor_threshold(tensor, n):
    thresholds = [sorted(arr)[-n] for arr in tensor]
    return [[0 if x < th else 1 for x in arr]
            for arr, th in zip(tensor, thresholds)]


T = [[0.8, 0.1, 0.9, 0.2], [0.7, 0.1, 0.4, 0.6]]

tensor_threshold(T, n=2)
>>> [[1, 0, 1, 0], [1, 0, 0, 1]]

相关问题 更多 >