Pythorch等效于index_Uadd_U,它取最大值

2024-04-23 18:15:06 发布

您现在位置:Python中文网/ 问答频道 /正文

在PyTorch中,张量的index_add_方法使用提供的索引张量进行求和:

idx = torch.LongTensor([0,0,0,0,1,1])
child = torch.FloatTensor([1, 3, 5, 10, 8, 1])
parent = torch.FloatTensor([0, 0])
parent.index_add_(0, idx, child)

前四个子值和为parent[0],后两个子值为parent[1],因此结果是tensor([ 19., 9.])

但是,我需要做index_max_,这在API中不存在。有没有一种有效的方法(不必循环或分配更多内存)?一个(坏的)循环解决方案是:

^{pr2}$

这将产生tensor([ 10., 8.])所需的结果,但速度非常慢。在


Tags: 方法内存addapichildindextorchpytorch
1条回答
网友
1楼 · 发布于 2024-04-23 18:15:06

利用指数的解决方案:

def index_max(child, idx, num_partitions): 
    # Building a num_partition x num_samples matrix `idx_tiled`:
    partition_idx = torch.range(0, num_partitions - 1, dtype=torch.long)
    partition_idx = partition_idx.view(-1, 1).expand(num_partitions, idx.shape[0])
    idx_tiled = idx.view(1, -1).repeat(num_partitions, 1)
    idx_tiled = (idx_tiled == partition_idx).float()
    # i.e. idx_tiled[i,j] == 1 if idx[j] == i, else 0

    parent = idx_tiled * child
    parent, _ = torch.max(parent, dim=1)
    return parent

标杆管理:

^{pr2}$

相关问题 更多 >