在不使用retain_graph=True的情况下反向传播两个具有不同损失的网络?

0 投票
2 回答
22 浏览
提问于 2025-04-14 16:58

我有两个网络,它们是串联在一起的,执行一个比较复杂的计算。

这两个网络的损失目标是一样的,不过我想在第二个网络的损失上加一个掩码。

我该怎么做才能不使用retain_graph=True呢?

# tenc          - network1
# unet          - network2

# the work flow is input->tenc->hidden_state->unet->output


params = []
params.append([{'params': tenc.parameters(), 'weight_decay': 1e-3, 'lr': 1e-07}])
params.append([{'params': unet.parameters(), 'weight_decay': 1e-2, 'lr': 1e-06}])
optimizer = torch.optim.AdamW(itertools.chain(*params), lr=1, betas=(0.9, 0.99), eps=1e-07, fused = True, foreach=False)
scheduler = custom_scheduler(optimizer=optimizer, warmup_steps= 30, exponent= 5, random=False)
scaler = torch.cuda.amp.GradScaler() 


loss = torch.nn.functional.mse_loss(model_pred, target, reduction='none')
loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()

scaler.scale(loss_tenc).backward(retain_graph=True)
scaler.scale(loss_unet).backward()
scaler.unscale_(optimizer)

scaler.step(optimizer)
scaler.update()

scheduler.step()
optimizer.zero_grad(set_to_none=True)

loss_tenc 这个损失只应该优化 tenc 的参数,而 loss_unet 只优化 unet 的参数。如果需要的话,我可能得用两个不同的优化器,但为了简单起见,我在这里把它们放在一起了。

2 个回答

0

考虑到这两个组件都连接到 model_pred,你可以通过将两个损失值加在一起,进行一次反向传播:

loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()

scaler.scale(loss_tenc + loss_unet).backward()
0

你为什么不想用 retain_graph=True 呢?其实你的问题可以通过不同的优化器来解决,并且在第二次调用 .backward() 之前,先对 unet 使用 .zero_grad()。这样做的步骤应该是:

ten_c.zero_grad() # model can call .zero_grad() 
loss_tenc.backward(retain_graph=True)
ten_c_optimizer.step() # or equivalence with scaler.

unet.zero_grad() # model can call .zero_grad() 
loss_unet.backward()
unet_optimizer.step() # or equivalence with scaler.

# zero_grad() both for safety
ten_c.zero_grad()
unet.zero_grad()

如果你只想用一个优化器,那么先计算 ten_c 的梯度,然后冻结 ten_c 的权重,再对 unet 使用 zero_grad(),然后再调用第二次 .backward()

ten_c.zero_grad() # model can call .zero_grad() 
loss_tenc.backward(retain_graph=True)
for p in ten_c.parameters():
    p.requires_grad = False # freeze the weight so the gradient will not be updated on the second `.backward()`

unet.zero_grad() # model can call .zero_grad() 
loss_unet.backward() # Only calculate gradient for unet as ten_c have been freeze

for p in ten_c.parameters():
    p.requires_grad = True

optimizer.step() # or equivalence with scaler.


# zero_grad() both for safety
ten_c.zero_grad()
unet.zero_grad()

撰写回答