tf.custom_渐变在存在舍入时不输出任何值

2024-04-29 07:40:35 发布

您现在位置:Python中文网/ 问答频道 /正文

我希望有一个自定义渐变,如下所示,并进行取整,然后在一个简单的操作中获得自定义渐变。当我运行下面的代码时,它不会使用掩码的tf.cast输出任何值>;=0.5美元。这应该很简单……我在这里遗漏了什么?请注意,我希望调用softhardthresh函数的输出不会以任何方式影响渐变计算

@tf.custom_gradient
def softhardthresh(x):
    mask = tf.nn.sigmoid(x)
    def grad(dy):
        return dy * (mask * (1 - mask))
    return tf.cast(mask >= 0.5, tf.float32), grad

inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(32, activation=actfun)(inputs)
output = layers.Dense(32)(x) ## hard threshold
output = softhardthresh(output)
mask_ann = tf.keras.Model(inputs, output, name='mask_ann')

with tf.GradientTape(persistent=False) as tape:
  x = tf.random.uniform(shape=[10,32])
  tape.watch(x)
  y = mask_ann(x)

grads = tape.gradient(y, mask_ann.trainable_variables)

Tags: outputreturntfdefmaskkerasinputsshape