将自定义渐变定义为以十为单位的类方法

2024-04-16 08:28:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要将方法定义为自定义渐变,如下所示:

class CustGradClass:

    def __init__(self):
        pass

    @tf.custom_gradient
    def f(self,x):
      fx = x
      def grad(dy):
        return dy * 1
      return fx, grad

我得到以下错误:

ValueError: Attempt to convert a value (<main.CustGradClass object at 0x12ed91710>) with an unsupported type () to a Tensor.

原因是自定义渐变接受函数f(*x),其中x是张量序列。传递的第一个参数是对象本身,即self。在

documentation

f: function f(*x) that returns a tuple (y, grad_fn) where:
x is a sequence of Tensor inputs to the function. y is a Tensor or sequence of Tensor outputs of applying TensorFlow operations in f to x. grad_fn is a function with the signature g(*grad_ys)

我该怎么做?我需要继承一些python tensorflow类吗?在

我使用的是tf版本1.12.0和eager模式。在


Tags: oftoselfreturnistfdefwith
2条回答

在您的示例中,您没有使用任何成员变量,因此您可以将该方法设置为静态方法。如果使用成员变量,则从成员函数调用静态方法,并将成员变量作为参数传递。在

class CustGradClass:

  def __init__(self):
    self.some_var = ...

  @staticmethod
  @tf.custom_gradient
  def _f(x):
    fx = x
    def grad(dy):
      return dy * 1

    return fx, grad

  def f(self):
    return CustGradClass._f(self.some_var)

这是一种可能的简单解决方法:

import tensorflow as tf

class CustGradClass:

    def __init__(self):
        self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))

    @staticmethod
    def _f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.constant(1.0)
    c = CustGradClass()
    y = c.f(x)
    print(tf.gradients(y, x))
    # [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]

编辑:

如果您想在不同的类上多次这样做,或者只是想要一个更可重用的解决方案,您可以使用像这样的decorator,例如:

^{pr2}$

然后你就可以:

class CustGradClass:

    def __init__(self):
        pass

    @tf_custom_gradient_method
    def f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

    @tf_custom_gradient_method
    def f2(self, x):
        fx = x * 2
        def grad(dy):
            return dy * 2
        return fx, grad

相关问题 更多 >