PyTorch如何在多个维度上进行聚集

2024-03-28 15:22:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我想找到一种不用for循环的方法

假设我有一个多维张量t0

bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))

它的形状是:torch.Size([4, 10, 16])

我有另一个张量labels,它是seq维中5个随机指数的一批:

labels = torch.randint(0, seq, size=[bs, sample])

这就是形状torch.Size([4, 5])。这用于索引t0seq维度

我想做的是使用labels张量在批处理维度上循环进行聚集。我的暴力解决方案是:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]

导致张量t1的形状:torch.Size([4, 5, 16])

在pytorch中有没有更惯用的方法


Tags: sample方法inforsizelabelsbstorch
2条回答

你可以这样做:

t1 = t0[[[b] for b in range(bs)], labels]

t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])

您可以在这里使用fancy indexing来选择所需的张量部分

本质上,如果预先生成传递访问模式的索引数组,则可以直接使用它们提取张量的某些片段。每个维度的索引数组的形状应与要提取的输出张量或切片的形状相同

i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1,      1]
j = labels.reshape(bs, sample, 1)      # shape = [bs, sample, 1]
k = torch.arange(v)                    # shape = [v, ]

# Get result as
t1 = t0[i, j, k]

注意上面3个张量的形状Broadcasting在张量的前面附加额外的维度,从而从本质上将k重塑为[1, 1, v]形状,这使得所有3个维度都兼容元素操作

广播后(i, j, k)一起将产生3[bs, sample, v]形状的数组,这些数组将(按元素)索引原始张量以产生形状[bs, sample, v]的输出张量t1

相关问题 更多 >