tf.grad()返回零矩阵

2024-06-16 16:29:58 发布

您现在位置:Python中文网/ 问答频道 /正文

我在MNIST数据集上训练了一个CNN,并计划用它来推断tf.js

然而,当检查模型相对于其中一个测试图像的梯度时,我注意到 我一直得到一个零矩阵。如果这是一个重要的观察结果 零矩阵具有输入图像的正确形状

function getGradient(img, yTrue) {

    yTrue = tf.oneHot(tf.tensor1d([yTrue], 'int32'), 10);

    function f(x) {
        return tf.metrics.categoricalCrossentropy(yTrue, model.predict(x));    
    }
    var g = tf.grad(f);

    var grad = g(img);
    return grad;
}

Python等价物允许我成功地计算梯度

loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

//already used one-hot encoding to convert the input_label
def getGradient(input_image, input_label):
  with tf.GradientTape() as tape:
    tape.watch(input_image)
    prediction = model(input_image)
    loss = loss_object(input_label, prediction)

  gradient = tape.gradient(loss, input_image)

  return gradient 


Tags: 图像imageinputreturntffunction矩阵label