Pytorch:将张量按列拆分

2024-04-19 02:54:58 发布

您现在位置:Python中文网/ 问答频道 /正文

如何按列拆分张量(轴=1)。比如说

"""
input:              result:
tensor([[1, 1],     (tensor([1, 2, 3, 1, 2, 3]), 
        [2, 1],      tensor([1, 1, 2, 2, 3, 3]))
        [3, 2],  
        [1, 2],
        [2, 3],
        [3, 3]])
"""

我提出的解决方案是首先转置输入张量,拆分它,然后展平每个拆分张量。然而,有没有一种更简单、更有效的方法?多谢各位

import torch
x = torch.LongTensor([[1,1],[2,1],[3,2],[1,2],[2,3],[3,3]])
x1, x2 = torch.split(x.T, 1)
x1 = torch.flatten(x1)
x2 = torch.flatten(x2)
x1, x2 # output

Tags: 方法importinputoutputtorchresult解决方案split