在梯度计算中,tensorflow如何处理不可微节点?

2024-04-19 11:11:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我了解自动微分的概念,但找不到任何解释,张量流是如何计算不可微函数的误差梯度的,例如在我的损失函数中tf.where,或在我的图中tf.cond。它工作得很好,但是我想了解张量流是如何通过这些节点反向传播误差的,因为没有公式来计算它们的梯度。在


Tags: 函数概念节点tfwhere公式误差梯度
1条回答
网友
1楼 · 发布于 2024-04-19 11:11:21

tf.where的情况下,有一个具有三个输入的函数,条件C,true上的值T,false上的值{},以及一个输出Out。渐变接收一个值,必须返回三个值。目前,没有为这个条件计算梯度(这很难理解),所以您只需要对T和{}进行梯度计算。假设输入和输出都是向量,假设C[0]是{}。那么Out[0]来自{},它的梯度应该向后传播。另一方面,F[0]会被丢弃,所以它的梯度应该为零。如果Out[1]是{},那么{}的梯度应该传播而不是{}。因此,简而言之,对于T,你应该传播给定的梯度,其中C是{},在{}处使其为零,F则相反。如果你看一下the implementation of the gradient of ^{} (^{} operation),它确实做到了:

@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  zeros = array_ops.zeros_like(x)
  return (None, array_ops.where(c, grad, zeros), array_ops.where(
      c, zeros, grad))

注:输入值本身不用于计算,这将由产生这些输入的操作的梯度来完成。对于tf.condthe code is a bit more complicated,因为相同的操作(Merge)在不同的上下文中使用,而且tf.cond也在里面使用Switch操作。但是想法是一样的。本质上,Switch操作用于每个输入,因此被激活的输入(第一个如果条件是True,第二个则为其他)获得接收的梯度,而另一个输入获得“关闭”的梯度(如None),并且不会进一步传播回去。在

相关问题 更多 >