我应该如何实现对整个预测进行计算的tf.keras.Metric?

2024-04-26 04:07:14 发布

您现在位置:Python中文网/ 问答频道 /正文

tf.keras.Metric接口提供了一个有用的工具,用于实现附加指标,如损失/准确度。接口设计为在调用update_state(self, y_pred, y_true)时在批处理上更新,结果应在result(self)返回。然而,当实现一些指标,如FID、初始分数、预期校准误差时,我们必须查看整个预测集,而不是单个样本的迭代视图。我应该如何在Tensorflow中实现这些自定义度量?我应该使用另一个API吗


1条回答
网友
1楼 · 发布于 2024-04-26 04:07:14

正如我理解你的问题

class BinaryTruePositives(tf.keras.metrics.Metric):

  def __init__(self, name='binary_true_positives', **kwargs):
    super(BinaryTruePositives, self).__init__(name=name, **kwargs)
    self.true_positives = self.add_weight(name='tp', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.bool)

    values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
    values = tf.cast(values, self.dtype)
    if sample_weight is not None:
      sample_weight = tf.cast(sample_weight, self.dtype)
      sample_weight = tf.broadcast_to(sample_weight, values.shape)
      values = tf.multiply(values, sample_weight)
    self.true_positives.assign_add(tf.reduce_sum(values))

  def result(self):
    return self.true_positives

[来自tf2 docs的示例]
update_state()方法中,我们实现了度量计算函数。本质上,update_state()为整个小批量返回单个。因此,在tf2中可以进行小批量评估。但是,为了支持批量评估指标,您还必须更新整个输入管道

相关问题 更多 >

    热门问题