Gradient difference loss between two images in TensorFlow

Asked 2025-04-14 17:55

How can I implement the Gradient Difference Loss (GDL) between two images in TensorFlow, the same way it is done in PyTorch? Here is a link for reference: https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py. Thanks!

1 Answer

I found a solution at this link: https://github.com/jonasrothfuss/DeepEpisodicMemory/blob/master/models/loss_functions.py

import tensorflow as tf


def gradient_difference_loss(true, pred, alpha=2.0):
  """
  computes gradient difference loss of two images
  :param true: ground truth image, Tensor of shape (batch_size, frame_height, frame_width, num_channels)
  :param pred: predicted image, Tensor of shape (batch_size, frame_height, frame_width, num_channels)
  :param alpha: parameter of the used l-norm (exponent applied to the absolute gradient differences)
  :return: scalar Tensor, mean of the vertical and horizontal terms
  """
  #tf.assert_equal(tf.shape(true), tf.shape(pred))
  # vertical
  true_pred_diff_vert = tf.pow(tf.abs(difference_gradient(true, vertical=True) - difference_gradient(pred, vertical=True)), alpha)
  # horizontal
  true_pred_diff_hor = tf.pow(tf.abs(difference_gradient(true, vertical=False) - difference_gradient(pred, vertical=False)), alpha)
  # normalization over all dimensions
  return (tf.reduce_mean(true_pred_diff_vert) + tf.reduce_mean(true_pred_diff_hor)) / tf.to_float(2)


def difference_gradient(image, vertical=True):
  """
  :param image: Tensor of shape (batch_size, frame_height, frame_width, num_channels)
  :param vertical: boolean that indicates whether vertical or horizontal pixel gradient shall be computed
  :return: Tensor of shape (:, frame_height-1, frame_width, :) if vertical, else (:, frame_height, frame_width-1, :)
  """
  s = tf.shape(image)
  if vertical:
    return tf.abs(image[:, 0:s[1] - 1, :, :] - image[:, 1:s[1], :, :])
  else:
    return tf.abs(image[:, :, 0:s[2] - 1, :] - image[:, :, 1:s[2], :])

I only had to change the following line of the original code:

return (tf.reduce_mean(true_pred_diff_vert) + tf.reduce_mean(true_pred_diff_hor)) / tf.to_float(2)

to:

return (tf.reduce_mean(true_pred_diff_vert) + tf.reduce_mean(true_pred_diff_hor)) / tf.cast(2, tf.float32)

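This was necessary because of the TensorFlow version I am using, where tf.to_float is not available: it is deprecated in favor of tf.cast and was removed from the top-level API in TensorFlow 2.x. (I got this hint from https://github.com/google/tangent/issues/95#issuecomment-562551139.)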
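
If you want to sanity-check the TensorFlow result, here is a minimal NumPy sketch of the same computation. It assumes NHWC input of shape (batch_size, frame_height, frame_width, num_channels), like the code above. The names difference_gradient_np and gdl_reference are my own for illustration and do not come from either linked repository.

```python
import numpy as np


def difference_gradient_np(image, vertical=True):
    """Absolute difference between neighbouring pixels of an NHWC array."""
    if vertical:
        # Difference between each row and the row below it.
        return np.abs(image[:, :-1, :, :] - image[:, 1:, :, :])
    # Difference between each column and the column to its right.
    return np.abs(image[:, :, :-1, :] - image[:, :, 1:, :])


def gdl_reference(true, pred, alpha=2.0):
    """Hypothetical NumPy reference for the gradient difference loss above."""
    true = np.asarray(true, dtype=np.float32)
    pred = np.asarray(pred, dtype=np.float32)
    vert = np.abs(difference_gradient_np(true, True)
                  - difference_gradient_np(pred, True)) ** alpha
    hor = np.abs(difference_gradient_np(true, False)
                 - difference_gradient_np(pred, False)) ** alpha
    # Same normalization as the TensorFlow code: average the two means.
    return float((vert.mean() + hor.mean()) / 2.0)
```

Feeding the same arrays to gradient_difference_loss and to gdl_reference should give matching values up to float32 rounding. Two identical images should give a loss of 0.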
