如何在pytorch中同时运行多个分支?

2024-05-23 20:13:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在pytorch中建立一个包含多个分支的网络。但是我怎样才能并行运行多个分支,而不是逐个运行它们呢

不像tensorflow或keras,pytorch使用动态图,所以我不能预先定义并发处理

我查找了一些类似Pytork网络的官方工具,如InceptionNet,结果发现,Pytork连续运行,有多个分支

inception.py

def _forward(self, x):
    branch1x1 = self.branch1x1(x)

    branch5x5 = self.branch5x5_1(x)
    branch5x5 = self.branch5x5_2(branch5x5)

    branch3x3dbl = self.branch3x3dbl_1(x)
    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
    branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

    branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
    branch_pool = self.branch_pool(branch_pool)

    outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
    return outputs

四个分支一个接一个地运行,首先是branch1x1,然后是branch5x5,然后是branch3x3dbl,branch_池。然后,输出存储其结果,稍后将连接这些结果

这不是在浪费性能吗?我们该如何应对呢?


Tags: self网络branch定义tensorflow分支pytorchoutputs
1条回答
网友
1楼 · 发布于 2024-05-23 20:13:57

一般来说,只要使用pytorch提供的函数,就不必关心网络执行的性能

正如注释中指出的,对gpu的所有调用都是异步的。只要调用不依赖于数据,它就会被执行。所以在你的例子中,你有多个分支。Pytorch将安排所有操作,并根据数据相关性执行这些操作。由于分支之间不共享数据,因此它们将并行执行

那么你的情况呢

branch3x3dbl = self.branch3x3dbl_1(x)
branch1x1 = self.branch1x1(x)
branch3x3dbl = self.branch3x3dbl_1(x)

可能或多或少同时执行。以下所有层的情况都是一样的

相关问题 更多 >