如何在散点更新中启用渐变流?

2024-05-08 22:31:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图通过从所有可能的固定大小的窗口(例如5x5)获取数据,在训练循环中计算图像的局部方差图。为了对这个操作进行矢量化,我考虑在训练循环中使用scatter_update/scatter_nd_update使用类似于this的操作来扩展原始图像。这个操作的实质是将原始张量中的每个元素映射到新张量中的潜在多个位置,这些位置在训练循环中计算。在

但是,scatter_update不允许渐变传播,我为scatter_update创建一个简单的自定义渐变的尝试没有成功。在

@tf.RegisterGradient("CustomGrad")
def _clip_grad(unused_op, grad):
  return tf.constant(5., dtype=tf.float32, shape=(1)) # tf.clip_by_value(grad, -0.1, 0.1)

x = tf.Variable([3.0], dtype=tf.float32)
y = tf.get_variable('y', shape=(1), dtype=tf.float32)

g = tf.get_default_graph()
with g.gradient_override_map({"ScatterNdUpdate1": "CustomGrad"}):
    output = tf.scatter_nd_update(y, [[0]], x, name="ScatterNdUpdate1")
grad_custom = tf.gradients(output, y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(grad_custom)

运行上面的代码显示grad_custom包含{}。有人知道如何正确地实现一个可以在训练循环中使用的局部方差图吗?解决梯度问题也可以帮助我解决另一个问题。在


Tags: 图像getcliptfcustomupdate局部shape

热门问题