在不使用retain_graph=True的情况下反向传播两个具有不同损失的网络?
我有两个网络,它们是串联在一起的,执行一个比较复杂的计算。
这两个网络的损失目标是一样的,不过我想在第二个网络的损失上加一个掩码。
我该怎么做才能不使用retain_graph=True呢?
# tenc - network1
# unet - network2
# the work flow is input->tenc->hidden_state->unet->output
params = []
params.append([{'params': tenc.parameters(), 'weight_decay': 1e-3, 'lr': 1e-07}])
params.append([{'params': unet.parameters(), 'weight_decay': 1e-2, 'lr': 1e-06}])
optimizer = torch.optim.AdamW(itertools.chain(*params), lr=1, betas=(0.9, 0.99), eps=1e-07, fused = True, foreach=False)
scheduler = custom_scheduler(optimizer=optimizer, warmup_steps= 30, exponent= 5, random=False)
scaler = torch.cuda.amp.GradScaler()
loss = torch.nn.functional.mse_loss(model_pred, target, reduction='none')
loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()
scaler.scale(loss_tenc).backward(retain_graph=True)
scaler.scale(loss_unet).backward()
scaler.unscale_(optimizer)
scaler.step(optimizer)
scaler.update()
scheduler.step()
optimizer.zero_grad(set_to_none=True)
loss_tenc 这个损失只应该优化 tenc 的参数,而 loss_unet 只优化 unet 的参数。如果需要的话,我可能得用两个不同的优化器,但为了简单起见,我在这里把它们放在一起了。
2 个回答
0
考虑到这两个组件都连接到 model_pred
,你可以通过将两个损失值加在一起,进行一次反向传播:
loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()
scaler.scale(loss_tenc + loss_unet).backward()
0
你为什么不想用 retain_graph=True
呢?其实你的问题可以通过不同的优化器来解决,并且在第二次调用 .backward()
之前,先对 unet
使用 .zero_grad()
。这样做的步骤应该是:
ten_c.zero_grad() # model can call .zero_grad()
loss_tenc.backward(retain_graph=True)
ten_c_optimizer.step() # or equivalence with scaler.
unet.zero_grad() # model can call .zero_grad()
loss_unet.backward()
unet_optimizer.step() # or equivalence with scaler.
# zero_grad() both for safety
ten_c.zero_grad()
unet.zero_grad()
如果你只想用一个优化器,那么先计算 ten_c
的梯度,然后冻结 ten_c
的权重,再对 unet
使用 zero_grad()
,然后再调用第二次 .backward()
:
ten_c.zero_grad() # model can call .zero_grad()
loss_tenc.backward(retain_graph=True)
for p in ten_c.parameters():
p.requires_grad = False # freeze the weight so the gradient will not be updated on the second `.backward()`
unet.zero_grad() # model can call .zero_grad()
loss_unet.backward() # Only calculate gradient for unet as ten_c have been freeze
for p in ten_c.parameters():
p.requires_grad = True
optimizer.step() # or equivalence with scaler.
# zero_grad() both for safety
ten_c.zero_grad()
unet.zero_grad()