如何在Pytorch1.1&DistributedDataParallel()中计算米数?

2024-06-07 07:06:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我想同时使用模型并行和数据并行,并且已经阅读了许多官方网站上的文档和教程。 我面临的一个令人困惑的问题是如何在每个过程中收集各种仪表值

问题1:在official tutorial中,他们只记录每个过程中的米值。 但在我的代码中,我在每个过程中打印损失值,它们是不同的。所以,我觉得其他仪表的数值也不一样。 那教程错了吗?在我看来,正确的方法应该是先同步loss、acc和其他仪表,然后所有进程保持相同的值,然后我只需要在一个进程中打印仪表信息

问题2:在official tutorial中,他们说“DistributedDataParallel模块还处理世界各地梯度的平均值,因此我们不必在训练步骤中显式平均梯度”。 但是,由于问题1,API是否真的像教程所说的那样工作?因为每个进程都有不同的损失值,尽管它们从相同的初始权重开始,每个进程中的模型权重是否会在不同的方向上优化?


Tags: 数据代码文档模型进程过程记录仪表