PyTorch:PyTorch中的numpy.linalg.multi_dot()等价物是什么

2024-03-29 12:32:11 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在PyTorch中执行多个矩阵的矩阵乘法,并且想知道PyTorch中numpy.linalg.multi_dot()的等价物是什么

如果没有,那么在PyTorch中,下一个最好的方法(在速度和内存方面)是什么

代码:

import numpy as np
import torch

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)

results = np.linalg.multi_dot(A, B, C)

A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)

# What is the PyTorch equivalent of np.linalg.multi_dot()?

非常感谢


Tags: 方法importnumpynp矩阵randomtorchpytorch
1条回答
网友
1楼 · 发布于 2024-03-29 12:32:11

~~看起来可以将张量发送到多点~

看起来numpy实现将所有内容都强制转换到numpy阵列中。如果你的张量在cpu上并且分离,这应该可以工作。否则,到numpy的转换将失败

因此,总的来说,很可能没有其他选择。我认为最好的方法是使用multi_dot实现,例如from here for numpy v1.19.0并将其调整为处理张量/跳过转换为numpy。考虑到类似的接口和代码的简单性,我认为这应该非常简单

相关问题 更多 >