具有自定义损失函数的Tensorflow nan bug

2024-04-23 22:35:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我在调试tensorflow代码以找到nan值的来源时遇到了一些问题。我修改了pix2pix代码,以便l1 loss现在正在计算像素之间的cie94颜色距离:

 with tf.name_scope("generator_loss"):
        # predict_fake => 1
        # abs(targets - outputs) => 0
        kl = 1.0
        k1 = 0.045
        k2 = 0.015
        targetslab = rgb_to_lab(targets)
        outputslab = rgb_to_lab(outputs)

        targets_l, targets_a, targets_b = preprocess_lab(targetslab)
        outputs_l, outputs_a, outputs_b = preprocess_lab(outputslab)
        euclidean = tf.sqrt(tf.square(targets_l-outputs_l) + tf.square(targets_a-outputs_a) + tf.square(targets_b-outputs_b))
        delta_l = tf.abs(targets_l - outputs_l)
        c1 = tf.sqrt(tf.square(targets_a) + tf.square(targets_b))
        c2 = tf.sqrt(tf.square(outputs_a) + tf.square(outputs_b))
        delta_c = tf.abs(c1 - c2)
        delta_h = tf.sqrt(tf.square(euclidean) - tf.square(delta_l) - tf.square(delta_c))
        #delta_a = tf.abs(targets_a - outputs_a)
        #delta_b = tf.abs(targets_b - outputs_b)
        sl = 1.0
        sc = 1.0 + tf.multiply(k1, c1)
        sh = 1.0 + tf.multiply(k2, c1)
        gen_loss_L1 = tf.reduce_mean(tf.sqrt(tf.square(tf.truediv(delta_l, tf.multiply(kl, sl))
                                                       + tf.truediv(delta_c, sc) + tf.truediv(delta_h, sh))))
        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        #gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
        gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight

从l1更改为自定义的损失函数后,我在输出中不断获取nan值。以前一切都很好,所以我在目标/输出中绝对没有nan值。在


Tags: l1tflababssqrtnanoutputsmultiply