I got the following loss function from this:
import tensorflow as tf  # assumes TensorFlow 1.x, where tf.Print exists
from keras import backend as K

def weightedLoss(originalLossFunc, weightsList):

    def lossFunc(true, pred):

        axis = -1  # if channels last
        # axis = 1  # if channels first

        # argmax returns the index of the element with the greatest value;
        # done along the class axis, it returns the class index
        classSelectors = K.argmax(true, axis=axis)

        # considering weights are ordered by class, for each class:
        # true (1) if the class index is equal to the weight index
        classSelectors = [K.equal(i, classSelectors) for i in range(len(weightsList))]

        # cast booleans to floats for the calculations;
        # each tensor in the list contains 1 where the ground-truth class equals its index,
        # so summing all of them gives a tensor full of ones
        classSelectors = [K.cast(x, K.floatx()) for x in classSelectors]

        # multiply each of the selections above by its respective weight
        weights = [sel * w for sel, w in zip(classSelectors, weightsList)]

        # sum all the selections; the result is a tensor holding
        # the respective weight for each element in the predictions
        weightMultiplier = weights[0]
        for i in range(1, len(weights)):
            weightMultiplier = weightMultiplier + weights[i]

        # make sure your originalLossFunc only collapses the class axis;
        # you need the other axes intact to multiply by the weights tensor
        loss = originalLossFunc(true, pred)
        weightMultiplier = tf.Print(weightMultiplier, [weightMultiplier], "loss weightage")
        loss = loss * weightMultiplier
        # weightMultiplier = tf.Print(weightMultiplier, [weightMultiplier], "loss weightage")  ---location 2
        return loss

    return lossFunc
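To make clear what the weighting part of `lossFunc` is supposed to produce (independently of the printing problem), here is a minimal pure-Python sketch of the same selector-times-weight-then-sum logic. It assumes one-hot labels with the class along the last axis (channels last), flattened to one row per pixel/sample; the names `argmax` and `weight_map` are hypothetical helpers, not Keras functions. Ties in `argmax` go to the first index here.

```python
# Hypothetical helpers; pure Python, no TensorFlow/Keras needed.

def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def weight_map(one_hot_labels, weights_list):
    """Per-sample weight, built the same way as weightMultiplier in lossFunc."""
    # Class index of each sample (the K.argmax step)
    class_indices = [argmax(row) for row in one_hot_labels]

    # One 0/1 selector per class (the K.equal + K.cast steps)
    selectors = [[1.0 if idx == c else 0.0 for idx in class_indices]
                 for c in range(len(weights_list))]

    # Multiply each selector by its class weight and sum over classes
    multiplier = [0.0] * len(class_indices)
    for sel, w in zip(selectors, weights_list):
        multiplier = [m + s * w for m, s in zip(multiplier, sel)]
    return multiplier


labels = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
print(weight_map(labels, [0.5, 2.0, 10.0]))  # [0.5, 10.0, 2.0]
```

Since exactly one selector is 1 for each sample, the sum simply picks `weights_list[class_index]`, which is the vector the print statement is expected to show.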
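Now, inside this function I have a print statement that prints the weight vector. In its current position the network prints nothing, even though I thought placing it there would make the network include it in the computation graph. I then moved it one line down (location 2) and tried again, but that didn't work either. What am I doing wrong? I never get any error at all. Any ideas?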