计算感知损失VGG特征的正确方法

2024-03-28 12:09:10 发布

您现在位置:Python中文网/ 问答频道 /正文

在计算VGG Perceptual loss时,虽然我没有看到,但我觉得可以将GT图像的VGG特性的计算包装在torch.no_grad()

所以基本上我觉得以下几点是可以的

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

# Now compute L1 loss

或者应该用,

gt_vgg_features = self.vgg_features(gt)
nw_op_vgg_features = self.vgg_features(nw_op)

在这两种方法中requires_grad为VGG参数设置False,VGG置于eval()模式

第一种方法将节省大量GPU资源,并且由于不需要通过GT图像进行反向传播,因此感觉应该与第二种方法在数值上相等。但在大多数实现中,我发现第二种方法用于计算VGG感知损失

那么,在PyTorch中实现VGG感知损失应该采用哪种方法呢


1条回答
网友
1楼 · 发布于 2024-03-28 12:09:10

第一种方式:

with torch.no_grad():
    gt_vgg_features = self.vgg_features(gt)

nw_op_vgg_features = self.vgg_features(nw_op)

尽管VGG处于eval模式,并且其参数保持不变,但您仍然需要通过它将渐变从特性上的丢失传播到输出nw_op。 但是,没有理由计算这些梯度w.r.tgt

相关问题 更多 >