Pytorch parameters()在收集到列表或保存在gen中时表现不同

2024-06-16 10:57:37 发布

您现在位置:Python中文网/ 问答频道 /正文

TL;DR—使用生成器失败,使用列表成功。为什么?

我正在尝试手动更改模型的参数,如下所示:

(第一个代码,工作)

       delta = r_t + gamma * expected_reward_from_t1.data - expected_reward_from_t.data

        negative_expected_reward_from_t = -expected_reward_from_t
        self.critic_optimizer.zero_grad()
        negative_expected_reward_from_t.backward()

        for i, p in enumerate(self.critic_nn.parameters()):
             if not p.requires_grad:
                 continue
             p.grad[:] = delta.squeeze() * discount * p.grad

        self.critic_optimizer.step()

它似乎100%地收敛于正确的结果

但是,

当试图使用这样的函数时:

(第二个代码,失败)

def _update_grads(self,delta, discount):
    params = self.critic_nn.parameters()
    for i, p in enumerate(params):
        if not p.requires_grad:
            continue
        p.grad[:] = delta.squeeze() * discount * p.grad

然后呢

       delta = r_t + gamma * expected_reward_from_t1.data - expected_reward_from_t.data

        negative_expected_reward_from_t = -expected_reward_from_t
        self.critic_optimizer.zero_grad()
        negative_expected_reward_from_t.backward()
        self._update_grads(delta=delta,
                           discount=discount)

        self.critic_optimizer.step()

我唯一做的就是把self.critic_nn.parameters()放入一个临时局部变量params

现在网络没有聚合。

(第三个代码,同样有效)

在方法_update_gradsparams = self.critic_nn.parameters()中替换为params = list(self.critic_nn.parameters())

现在,收敛又恢复了


这似乎是一个参考问题,在PyTorch中,我并不完全理解。我似乎不完全理解parameters()返回的内容


问题: 为什么第一个和第三个代码有效,而第二个代码无效?


Tags: 代码fromselfdatadiscountnnparamsoptimizer