“怎么做?”三重半硬损耗“接到电话了吗?

2024-03-29 09:59:54 发布

您现在位置:Python中文网/ 问答频道 /正文

在Tensorflow插件中有两个提到的三元组丢失一个是基类tfa.losses.triplet_semihard_loss,另一个是tfa.losses.TripletSemiHardLoss,它是由用户初始化的子类,然后隐式地调用基类。在属于子类的这段代码中:

    def __init__(self, margin=1.0, name=None):
        super(TripletSemiHardLoss, self).__init__(
            name=name, reduction=tf.keras.losses.Reduction.NONE)
        self.margin = margin

    def call(self, y_true, y_pred):
        return triplet_semihard_loss(y_true, y_pred, self.margin)

我不明白call方法是怎么回事,它返回的基类函数给出了y_truey_pred数组,但它们究竟是从哪里来的?根据Tensorflow文档指南,子类在modelcompile语句中初始化为:

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())

然后将模型拟合为:

history = model.fit(
    train_dataset,
    epochs=5)

train_dataset结构是一个包含嵌入数据和相应整数标签的元组,但是子类如何认识到这是要操作的数据呢?call方法也是隐式调用的吗?你知道吗


Tags: namemarginselftruetensorflowcall基类子类
1条回答
网友
1楼 · 发布于 2024-03-29 09:59:54

在调用类的实例时调用__call__y_truey_pred分别包含模型预测的真实标签和标签。张量流(克拉斯特遣部队)内部转换给y_true的标签,如here所示,并使用model.fit()对数据进行训练。
所有tf.keras损失都以这种形式实现,即具有两个参数y_truey_pred的函数,如here。你知道吗

相关问题 更多 >