擅长:python、mysql、java
<p>这是一种可能的简单解决方法:</p>
<pre><code>import tensorflow as tf
class CustGradClass:
def __init__(self):
self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))
@staticmethod
def _f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.constant(1.0)
c = CustGradClass()
y = c.f(x)
print(tf.gradients(y, x))
# [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]
</code></pre>
<p>编辑:</p>
<p>如果您想在不同的类上多次这样做,或者只是想要一个更可重用的解决方案,您可以使用像这样的decorator,例如:</p>
^{pr2}$
<p>然后你就可以:</p>
<pre><code>class CustGradClass:
def __init__(self):
pass
@tf_custom_gradient_method
def f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
@tf_custom_gradient_method
def f2(self, x):
fx = x * 2
def grad(dy):
return dy * 2
return fx, grad
</code></pre>