试图创建自定义keras损失函数来解释“类距离”

2024-05-15 11:07:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试编写一个ResNet,它获取一个数据图像,并尝试从该事件中猜测一个连续变量(在本例中是能量)。为了做到这一点,我有点捏造类的定义:网络输出100个“类”的向量,这些“类”实际上是0-1000个能量单位(MeV)的等距能量箱,因此每个能量箱的宽度为10 MeV。所以,我想写一个自定义的损失函数,这样如果网络在一个离正确的能量箱更远的能量箱中输出一个高概率,那么惩罚会比在高概率的能量箱离正确的能量箱更近的情况下更严厉,即使高概率的能量箱在这两种情况下都是“错误的”。这是我为单幅图像提出的第一个传递损失函数:

Loss

在这里,带帽的i表示真正的能量箱。我在Keras背景下实现这个有点困难。具体来说,我找不到一种方法来实现只使用keras.backend文件函数,特别是当训练批大小未知时。在numpy/for循环语言中,如果M是已知的批大小,bin是网络输出中的能量bin(aka class)的数量,那么我会将损失写为:

def customLoss(yTrue,yPred):
    # yTrue and yPred have dimension (M,bins)
    trueBinVec = np.argmax(yTrue,axis=-1) # dimension (M)
    trueBinArr = np.ones((M,bins))
    binsArr = np.ones((M,bins))
    for m in range(M): 
        trueBinArr[m] *= trueBinVec[m]
        binsArr[m] = np.arange(bins)
    binDifTensor = (trueBinsArr - binsArr)**2
    finalTensor = np.multiply((yTrue-yPred)**2,binDifTensor)
    return np.mean(finalTensor,axis=-1) # this would return a tensor of shape (M) containing the loss for each image in the batch

问题是,我事先没有访问批大小的权限,因此很难创建binDifTensor,这样我就可以在keras后端用传统的均方误差张量按元素进行乘积。如果您对如何使用keras.backend文件功能!谢谢


Tags: 函数图像网络fornp概率keras能量