Keras/TensorFlow: unable to invoke tf.Print() in a loss function

Posted 2024-04-19 07:38:58


I got the following loss function from this:

import tensorflow as tf
from keras import backend as K

def weightedLoss(originalLossFunc, weightsList):

    def lossFunc(true, pred):

        axis = -1 #if channels last 
        #axis=  1 #if channels first


        #argmax returns the index of the element with the greatest value
        #done in the class axis, it returns the class index    
        classSelectors = K.argmax(true, axis=axis) 

        #considering weights are ordered by class, for each class
        #true(1) if the class index is equal to the weight index   
        classSelectors = [K.equal(i, classSelectors) for i in range(len(weightsList))]

        #casting boolean to float for calculations  
        #each tensor in the list contains 1 where the ground truth class is equal to its index 
        #if you sum all these, you will get a tensor full of ones. 
        classSelectors = [K.cast(x, K.floatx()) for x in classSelectors]

        #for each of the selections above, multiply their respective weight
        weights = [sel * w for sel,w in zip(classSelectors, weightsList)] 

        #sums all the selections
        #result is a tensor with the respective weight for each element in predictions
        weightMultiplier = weights[0]
        for i in range(1, len(weights)):
            weightMultiplier = weightMultiplier + weights[i]


        #make sure your originalLossFunc only collapses the class axis
        #you need the other axes intact to multiply the weights tensor
        loss = originalLossFunc(true,pred) 
        weightMultiplier = tf.Print(weightMultiplier, [weightMultiplier], "loss weightage")
        loss = loss * weightMultiplier
        #weightMultiplier = tf.Print(weightMultiplier, [weightMultiplier], "loss weightage") ---location 2
        return loss
    return lossFunc

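In this function I have a print statement that should print the weight vector. In its current position the network prints nothing, even though I assumed that placing it there would make the network include it in the computation graph. I then moved it down one line (marked "location 2" in the code) and tried again, but that did not work either. What am I doing wrong? I never get any error.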
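
For reference, the values the print statement is expected to show can be sketched in plain Python, without Keras or TensorFlow. This is only an illustration of the steps described in the comments of lossFunc (argmax over the class axis, one selector per class, weighted sum); the helper name `weight_multiplier`, the sample labels, and the weights are all made up for this sketch and are not part of the original code.

```python
def weight_multiplier(true_rows, weights_list):
    """Return one weight per sample, mirroring the steps in lossFunc (hypothetical helper)."""
    multipliers = []
    for row in true_rows:
        # argmax over the class axis: index of the largest entry in the one-hot row
        class_index = max(range(len(row)), key=lambda i: row[i])
        # one selector per class: 1.0 where the class index matches, else 0.0
        selectors = [1.0 if i == class_index else 0.0
                     for i in range(len(weights_list))]
        # scale each selector by its class weight, then sum the selections
        multipliers.append(sum(sel * w for sel, w in zip(selectors, weights_list)))
    return multipliers


# Hypothetical data: three samples, three classes, made-up class weights
y_true = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
class_weights = [0.5, 2.0, 4.0]
result = weight_multiplier(y_true, class_weights)
print(result)
```

With one-hot labels, each sample simply picks up the weight of its true class, so a working tf.Print on weightMultiplier should show a tensor of per-sample class weights of this kind.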

