我试图在PyTorch中执行多个矩阵的矩阵乘法,并且想知道PyTorch中numpy.linalg.multi_dot()
的等价物是什么
如果没有,那么在PyTorch中,下一个最好的方法(在速度和内存方面)是什么
代码:
import numpy as np
import torch
A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)
results = np.linalg.multi_dot(A, B, C)
A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)
# What is the PyTorch equivalent of np.linalg.multi_dot()?
非常感谢
~~看起来可以将张量发送到多点~
看起来numpy实现将所有内容都强制转换到numpy阵列中。如果你的张量在cpu上并且分离,这应该可以工作。否则,到numpy的转换将失败
因此,总的来说,很可能没有其他选择。我认为最好的方法是使用
multi_dot
实现,例如from here for numpy v1.19.0并将其调整为处理张量/跳过转换为numpy。考虑到类似的接口和代码的简单性,我认为这应该非常简单相关问题 更多 >
编程相关推荐