在计算VGG Perceptual loss时,虽然我没有看到,但我觉得可以将GT图像的VGG特性的计算包装在torch.no_grad()
中
所以基本上我觉得以下几点是可以的
with torch.no_grad():
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
# Now compute L1 loss
或者应该用,
gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)
在这两种方法中requires_grad
为VGG参数设置False
,VGG置于eval()
模式
第一种方法将节省大量GPU资源,并且由于不需要通过GT图像进行反向传播,因此感觉应该与第二种方法在数值上相等。但在大多数实现中,我发现第二种方法用于计算VGG感知损失
那么,在PyTorch中实现VGG感知损失应该采用哪种方法呢
第一种方式:
尽管VGG处于
eval
模式,并且其参数保持不变,但您仍然需要通过它将渐变从特性上的丢失传播到输出nw_op
。 但是,没有理由计算这些梯度w.r.tgt
相关问题 更多 >
编程相关推荐