2024-04-26 04:07:14 发布
网友
tf.keras.Metric接口提供了一个有用的工具,用于实现附加指标,如损失/准确度。接口设计为在调用update_state(self, y_pred, y_true)时在批处理上更新,结果应在result(self)返回。然而,当实现一些指标,如FID、初始分数、预期校准误差时,我们必须查看整个预测集,而不是单个样本的迭代视图。我应该如何在Tensorflow中实现这些自定义度量?我应该使用另一个API吗
tf.keras.Metric
update_state(self, y_pred, y_true)
result(self)
正如我理解你的问题
class BinaryTruePositives(tf.keras.metrics.Metric): def __init__(self, name='binary_true_positives', **kwargs): super(BinaryTruePositives, self).__init__(name=name, **kwargs) self.true_positives = self.add_weight(name='tp', initializer='zeros') def update_state(self, y_true, y_pred, sample_weight=None): y_true = tf.cast(y_true, tf.bool) y_pred = tf.cast(y_pred, tf.bool) values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True)) values = tf.cast(values, self.dtype) if sample_weight is not None: sample_weight = tf.cast(sample_weight, self.dtype) sample_weight = tf.broadcast_to(sample_weight, values.shape) values = tf.multiply(values, sample_weight) self.true_positives.assign_add(tf.reduce_sum(values)) def result(self): return self.true_positives
[来自tf2 docs的示例] 在update_state()方法中,我们实现了度量计算函数。本质上,update_state()为整个小批量返回单个。因此,在tf2中可以进行小批量评估。但是,为了支持批量评估指标,您还必须更新整个输入管道
update_state()
正如我理解你的问题
[来自tf2 docs的示例]
在
update_state()
方法中,我们实现了度量计算函数。本质上,update_state()
为整个小批量返回单个。因此,在tf2中可以进行小批量评估。但是,为了支持批量评估指标,您还必须更新整个输入管道相关问题 更多 >
编程相关推荐