pytorch:如何进行分层乘法?

2024-04-19 20:54:44 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个包含五个2x2矩阵的张量-形状(1,5,2,2),还有一个包含五个元素的张量-形状([5])。我想将每个2x2矩阵(在前一个张量中)乘以相应的值(在后一个张量中)。合成张量应为形状(1,5,2,2)。怎么做

运行此代码时出现以下错误

a = torch.rand(1,5,2,2)
print(a.shape)
b = torch.rand(5)
print(b.shape)
mul = a*b

RuntimeError: The size of tensor a (2) must match the size of tensor b (5) at non-singleton dimension 3

Tags: ofthe代码元素size错误矩阵torch
1条回答
网友
1楼 · 发布于 2024-04-19 20:54:44

您可以使用a * btorch.mul(a, b)但是必须在乘法之前和之后使用permute(),以获得兼容的形状:

import torch
a = torch.ones(1,5,2,2)
b = torch.rand(5)
a.shape # torch.Size([1, 5, 2, 2])
b.shape # torch.Size([5])

c = (a.permute(0,2,3,1) * b).permute(0,3,1,2)
c.shape # torch.Size([1, 5, 2, 2])

# OR #
c = torch.mul(a.permute(0,2,3,1), b).permute(0,3,1,2)
c.shape # torch.Size([1, 5, 2, 2])

permute()函数按照其参数的顺序转换维度。也就是说,a.permute(0,2,3,1)的形状是torch.Size([1,2,2,5]),它适合矩阵乘法的b(torch.Size([5]),因为a的最后一个维度等于b的第一个维度。完成乘法后,我们使用permute()再次将其转置到。所需的焊炬形状。通过排列(0,3,1,2)的尺寸([1,5,2,2])

您可以在docs中阅读有关permute()的内容。但它使用的参数将[1,5,2,2]的当前形状编号为0到3,并在插入参数时进行排列,这意味着a.permute(0,2,3,1)它将保持第一个维度的位置,因为第一个参数是0,第二个维度它将移动到第四个维度,因为索引1是第四个参数。第三和第四维度将移动到第二和第三维度,因为第二和第三维度的索引位于第二和第三位。记住,当谈到第四维时,它作为参数的表示是3(不是4)

编辑 例如,如果希望按元素对形状[32,5,2,2]和[32,5]的张量进行乘法,使每个2x2矩阵都乘以相应的值,则可以将维度重新排列为[2,2,32,5]乘以permute(2,3,0,1),然后执行a * b的乘法,然后再次返回到原始形状。这里的关键是,第一个矩阵的最后n维需要与第二个矩阵的第一n维对齐。在我们的例子中n=2

希望有帮助

相关问题 更多 >