在Tensorflow插件中有两个提到的三元组丢失一个是基类tfa.losses.triplet_semihard_loss
,另一个是tfa.losses.TripletSemiHardLoss
,它是由用户初始化的子类,然后隐式地调用基类。在属于子类的这段代码中:
def __init__(self, margin=1.0, name=None):
super(TripletSemiHardLoss, self).__init__(
name=name, reduction=tf.keras.losses.Reduction.NONE)
self.margin = margin
def call(self, y_true, y_pred):
return triplet_semihard_loss(y_true, y_pred, self.margin)
我不明白call
方法是怎么回事,它返回的基类函数给出了y_true
和y_pred
数组,但它们究竟是从哪里来的?根据Tensorflow文档指南,子类在modelcompile
语句中初始化为:
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tfa.losses.TripletSemiHardLoss())
然后将模型拟合为:
history = model.fit(
train_dataset,
epochs=5)
train_dataset
结构是一个包含嵌入数据和相应整数标签的元组,但是子类如何认识到这是要操作的数据呢?call
方法也是隐式调用的吗?你知道吗
在调用类的实例时调用
__call__
。y_true
和y_pred
分别包含模型预测的真实标签和标签。张量流(克拉斯特遣部队)内部转换给y_true
的标签,如here所示,并使用model.fit()
对数据进行训练。所有
tf.keras
损失都以这种形式实现,即具有两个参数y_true
和y_pred
的函数,如here。你知道吗相关问题 更多 >
编程相关推荐