如何在PyTorch中做矩阵乘积

2024-04-20 07:28:15 发布

您现在位置:Python中文网/ 问答频道 /正文

在numpy中,我可以做如下简单的矩阵乘法:

a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))

但是,当我尝试使用PyTorch张量时,这不起作用:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())

print(b)
print(b.size())

print(torch.dot(a, b))

此代码引发以下错误:

RuntimeError: inconsistent tensor size at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503

你知道怎样在PyTorch中进行矩阵乘法吗?


Tags: 代码numpyviewsize错误矩阵torchpytorch
3条回答

使用torch.mm(a, b)torch.matmul(a, b)
两者都是一样的。

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

还有一个可能是好消息。 那是@运算符。@西蒙H

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])    

三个结果相同。

相关链接:
Matrix multiplication operator
PEP 465 -- A dedicated infix operator for matrix multiplication

如果你想做一个矩阵(秩2张量)乘法,你可以用四种等效的方法来做:

AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or
AB = torch.matmul(A, B)
# or, even simpler
AB = A @ B # Python 3.5+

有一些微妙之处。从PyTorch documentation

torch.mm does not broadcast. For broadcasting matrix products, see torch.matmul().

例如,不能用torch.mm乘两个一维向量,也不能乘成批矩阵(秩3)。为此,您应该使用更通用的torch.matmul。有关torch.matmul的广播行为的详细列表,请参见documentation

对于元素乘法,您可以简单地执行(如果A和B具有相同的形状)

A * B # element-wise matrix multiplication (Hadamard product)

你在找

torch.mm(a,b)

注意torch.dot()的行为与np.dot()不同。有人讨论过什么是理想的here。具体地说,torch.dot()ab视为一维向量(与它们的原始形状无关),并计算它们的内积。这个错误被抛出,因为这个行为使您的a成为长度为6的向量,而您的b成为长度为2的向量;因此无法计算它们的内积。对于PyTorch中的矩阵乘法,使用torch.mm()。相比之下,Numpy的np.dot()更加灵活;它计算1D数组的内积,并对2D数组执行矩阵乘法。

根据流行的需求,如果两个参数都是2D,则函数torch.matmul执行矩阵乘法,如果两个参数都是1D,则计算它们的点积。对于此类维度的输入,其行为与np.dot相同。它还允许您分批执行广播或matrix x matrixmatrix x vectorvector x vector操作。有关详细信息,请参见其docs

# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # torch.Size([])

# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # torch.Size([m, j])

相关问题 更多 >