获取批处理聚集的智能方法是什么?

2024-04-24 15:24:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个矩阵,AB,分别具有(n, m, k)(n, m)的形状n是批大小,m是批中的数据量,k是特征大小

B的每个元素都是小于m(特别是B = torch.randint(high=m, shape=(n,m)))的索引

我想以更智能的方式实现[A[i][B[i]] for i in range(n)]

在pytorch中有没有更好的方法来实现这一点而不进行for循环


Tags: in元素for智能方式range矩阵torch
1条回答
网友
1楼 · 发布于 2024-04-24 15:24:49

你可以用

a[torch.arange(n)[:, None], b]

例如:

>>> n, m, k = 3, 2, 5
>>> a = torch.arange(30).view(n, m, k)
>>> b = torch.randint(high=m, size=(n,m))

# first indexer (of shape (n, 1))
>>> torch.arange(n)[:, None]

tensor([[0],
        [1],
        [2]])

# second indexer
>>> b

tensor([[1, 0],
        [0, 1],
        [1, 1]])

索引器的形状分别为(3, 1)(3, 2),因此它们将被广播到(3, 2)以有效地

tensor([[0, 0],
        [1, 1],
        [2, 2]])

tensor([[1, 0],
        [0, 1],
        [1, 1]])

这表示:对于第一行,取第一个(k,)数组并放入结果,取第0个(k,)数组并放入结果。这将填充输出中的(m, k)数组,每行重复n

得到

>>> a[torch.arange(n)[:, None], b]

tensor([[[ 5,  6,  7,  8,  9],
         [ 0,  1,  2,  3,  4]],

        [[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[25, 26, 27, 28, 29],
         [25, 26, 27, 28, 29]]])

与列表理解相比:

>>> [a[i][b[i]] for i in range(n)]

[tensor([[5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4]]),
 tensor([[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]]),
 tensor([[25, 26, 27, 28, 29],
         [25, 26, 27, 28, 29]])]

相关问题 更多 >