我想找到一种不用for循环的方法
假设我有一个多维张量t0
:
bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))
它的形状是:torch.Size([4, 10, 16])
我有另一个张量labels
,它是seq
维中5个随机指数的一批:
labels = torch.randint(0, seq, size=[bs, sample])
这就是形状torch.Size([4, 5])
。这用于索引t0
的seq
维度
我想做的是使用labels
张量在批处理维度上循环进行聚集。我的暴力解决方案是:
t1 = torch.empty((bs, sample, v))
for b in range(bs):
for idx0, idx1 in enumerate(labels[b]):
t1[b, idx0, :] = t0[b, idx1, :]
导致张量t1
的形状:torch.Size([4, 5, 16])
在pytorch中有没有更惯用的方法
你可以这样做:
或
您可以在这里使用fancy indexing来选择所需的张量部分
本质上,如果预先生成传递访问模式的索引数组,则可以直接使用它们提取张量的某些片段。每个维度的索引数组的形状应与要提取的输出张量或切片的形状相同
注意上面3个张量的形状Broadcasting在张量的前面附加额外的维度,从而从本质上将
k
重塑为[1, 1, v]
形状,这使得所有3个维度都兼容元素操作广播后
(i, j, k)
一起将产生3[bs, sample, v]
形状的数组,这些数组将(按元素)索引原始张量以产生形状[bs, sample, v]
的输出张量t1
相关问题 更多 >
编程相关推荐