我希望有一个自定义渐变,如下所示,并进行取整,然后在一个简单的操作中获得自定义渐变。当我运行下面的代码时,它不会使用掩码的tf.cast输出任何值>;=0.5美元。这应该很简单……我在这里遗漏了什么?请注意,我希望调用softhardthresh函数的输出不会以任何方式影响渐变计算
@tf.custom_gradient
def softhardthresh(x):
mask = tf.nn.sigmoid(x)
def grad(dy):
return dy * (mask * (1 - mask))
return tf.cast(mask >= 0.5, tf.float32), grad
inputs = tf.keras.Input(shape=(32,))
x = layers.Dense(32, activation=actfun)(inputs)
output = layers.Dense(32)(x) ## hard threshold
output = softhardthresh(output)
mask_ann = tf.keras.Model(inputs, output, name='mask_ann')
with tf.GradientTape(persistent=False) as tape:
x = tf.random.uniform(shape=[10,32])
tape.watch(x)
y = mask_ann(x)
grads = tape.gradient(y, mask_ann.trainable_variables)
目前没有回答
相关问题 更多 >
编程相关推荐