Pytorch:如何在张量内连接列表?

2024-04-23 18:10:54 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个大小为(2, b, h)的张量,我想将其更改为以下大小:(b, 2*h),其中相应的列表是串联的,例如:

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [4, 4, 4]],
                  [[4, 5, 6], [7, 8, 9], [5, 5, 5]]])

我想:

b = tensor([[1, 2, 3, 4, 5, 6],
            [4, 5, 6, 7, 8, 9],
            [4, 4, 4, 5, 5, 5]])

Tags: 列表torchtensor
2条回答

首先使用排列来更改标注顺序,然后使用连续来防止排列张量内的跨步,最后使用视图来重塑张量

b = a.permute(1,0,2).contiguous().view(a.shape[1],-1)

为了连接pytorch中的张量,可以使用torch.cat函数沿选定轴连接张量。 在本例中,您可以执行以下操作:

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [4, 4, 4]],
                  [[4, 5, 6], [7, 8, 9], [5, 5, 5]]])

b = torch.cat((a[0], a[1]), dim=1)

Out:
tensor([[1, 2, 3, 4, 5, 6],
        [4, 5, 6, 7, 8, 9],
        [4, 4, 4, 5, 5, 5]])

相关问题 更多 >