如何合并PyTorch矩阵中相同元素的行?

1 投票
1 回答
41 浏览
提问于 2025-04-13 16:53

举个例子,假设有一个输入的矩阵 matrix = torch.tensor([[1, 2], [2, 3],[4, 5]]),我们希望得到的输出矩阵是 torch.tensor([[1, 3],[4, 5]]),因为第一行和第二行有一个相同的元素2。

那怎么才能做到这一点呢?谢谢!

1 个回答

2

你可以找出那些在第0列的值和第1列的前一个值不一样的行,然后把其他的行去掉:

mask = matrix[1:,0]!=matrix[:-1,1]
# tensor([False,  True])

true = torch.tensor([True])

out = torch.column_stack([matrix[torch.cat((true, mask)), 0],
                          matrix[torch.cat((mask, true)), 1],
                         ])

变体:

mask = torch.cat((torch.tensor([True]),
                  matrix[1:,0]!=matrix[:-1,1]))
# tensor([ True, False,  True])

out = torch.column_stack([matrix[mask, 0],
                          matrix[mask.roll(-1), 1],
                         ])

输出:

tensor([[1, 3],
        [4, 5]])

更复杂的例子:

# input             # values to keep (X)
tensor([[1, 2],     #    X   -
        [2, 3],     #    -   -
        [3, 4],     #    -   X
        [0, 0],     #    X   X
        [5, 6],     #    X   -
        [6, 9]])    #    -   X

# output
tensor([[1, 4],
        [0, 0],
        [5, 9]])

撰写回答