定制损耗输出(单位:Keras)

2024-04-20 15:26:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我知道在Keras中有很多关于自定义丢失函数的问题,但是我在google上搜索了3个小时后仍然无法回答这个问题。在

这是我的问题的一个非常简单的例子。我意识到这个例子是没有意义的,但我提供它是为了简单,我显然需要实现一些更复杂的东西。在

from keras.backend import binary_crossentropy
from keras.backend import mean
def custom_loss(y_true, y_pred):

    zeros = tf.zeros_like(y_true)
    index_of_zeros = tf.where(tf.equal(zeros, y_true))
    ones = tf.ones_like(y_true)
    index_of_ones = tf.where(tf.equal(ones, y_true))

    zero = tf.gather(y_pred, index_of_zeros)
    one = tf.gather(y_pred, index_of_ones)

    loss_0 = binary_crossentropy(tf.zeros_like(zero), zero)
    loss_1 = binary_crossentropy(tf.ones_like(one), one)

    return mean(tf.concat([loss_0, loss_1], axis=0))

我不明白为什么在两类数据集上用上面的损失函数训练网络不能产生与使用内置的binary-crossentropy损失函数训练相同的结果。 谢谢您!在

编辑:我编辑了代码片段,以包括下面的注释中的平均值。但我还是有同样的行为。在


Tags: of函数trueindextfoneszerosone
1条回答
网友
1楼 · 发布于 2024-04-20 15:26:19

我终于明白了。当形状为“未知”时,tf.where函数的行为非常不同。 要修复上面的代码段,只需在声明函数后插入以下行:

y_pred = tf.reshape(y_pred, [-1])
y_true = tf.reshape(y_true, [-1])

相关问题 更多 >