访问自定义损失函数中的特定元素?

2024-03-28 16:12:21 发布

您现在位置:Python中文网/ 问答频道 /正文

在为分类问题定义自定义损失函数时,是否有方法访问y\u true和y\u pred的特定元素?你知道吗

用例:多标签分类问题,如果我预测类5为假阳性,我想额外惩罚模型,即y_true[5] == 0y_pred[5] == 1

我把损失定义为:

def loss(y_true, y_pred):
      wt = 10 if (y_true[5]==0 and y_pred[5]==1) else 1
      return wt * binary_crossentropy(y_true, y_pred)

我还试着检查K.gather(y_true, 5) == 0,但似乎不行。你知道吗

我的批大小是>;1(256),我使用fit_generator——如果这有什么区别的话。谢谢!你知道吗


Tags: 方法函数模型true元素定义def分类
1条回答
网友
1楼 · 发布于 2024-03-28 16:12:21

Is there a way to access particular elements of y_true and y_pred?

Keras张量的索引工作与numpy数组的索引类似。唯一的区别是结果是Keras张量。因此,您应该随后使用Keras操作。你知道吗


损失函数的可能实现

例如,下面是如何实现损失函数:

def loss(y_true, y_pred):
    a = K.equal(y_true[:, 5], 0)
    b = K.greater(y_pred[:, 5], 0.5)
    condition = K.cast(a, 'float') * K.cast(b, 'float')
    wt = 10 * condition + (1 - condition)
    return K.mean(wt[:, None] * K.binary_crossentropy(y_true, y_pred), axis=-1)

注意:未测试。你知道吗

相关问题 更多 >