如何合并PyTorch矩阵中相同元素的行?
举个例子,假设有一个输入的矩阵 matrix = torch.tensor([[1, 2], [2, 3],[4, 5]])
,我们希望得到的输出矩阵是 torch.tensor([[1, 3],[4, 5]])
,因为第一行和第二行有一个相同的元素2。
那怎么才能做到这一点呢?谢谢!
1 个回答
2
你可以找出那些在第0列的值和第1列的前一个值不一样的行,然后把其他的行去掉:
mask = matrix[1:,0]!=matrix[:-1,1]
# tensor([False, True])
true = torch.tensor([True])
out = torch.column_stack([matrix[torch.cat((true, mask)), 0],
matrix[torch.cat((mask, true)), 1],
])
变体:
mask = torch.cat((torch.tensor([True]),
matrix[1:,0]!=matrix[:-1,1]))
# tensor([ True, False, True])
out = torch.column_stack([matrix[mask, 0],
matrix[mask.roll(-1), 1],
])
输出:
tensor([[1, 3],
[4, 5]])
更复杂的例子:
# input # values to keep (X)
tensor([[1, 2], # X -
[2, 3], # - -
[3, 4], # - X
[0, 0], # X X
[5, 6], # X -
[6, 9]]) # - X
# output
tensor([[1, 4],
[0, 0],
[5, 9]])