有没有类似的函数火炬.argmax哪个能真正保持原始数据的维度?

2024-03-29 08:50:19 发布

您现在位置:Python中文网/ 问答频道 /正文

例如, 代码是

input = torch.randn(3, 10)
result = torch.argmax(input, dim=0, keepdim=True)

input

^{pr2}$

并且result

tensor([[ 0,  2,  1,  2,  2]])

但是,我想要这样的结果

tensor([[ 1,  0,  0,  0,  0],
        [ 0,  0,  1,  0,  0],
        [ 0,  1,  0,  1,  1]])

Tags: 代码trueinputtorchresulttensordimrandn
1条回答
网友
1楼 · 发布于 2024-03-29 08:50:19

最后,我解决了。但这种解决方案可能并不有效。 代码如下:

input = torch.randn(3, 10)
result = torch.argmax(input, dim=0, keepdim=True)
result_0 = result == 0
result_1 = result == 1
result_2 = result == 2
result = torch.cat((result_0, result_1, result_2), 0)

相关问题 更多 >