使用tf.data.Datas每N步评估

2024-04-18 20:01:53 发布

您现在位置:Python中文网/ 问答频道 /正文

TensorFlow是否有某种方法可以使用tf.data.DatasetAPI在每N步训练中自动评估一个评估集?目前,我的输入函数如下所示:

def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((dict(train_x), train_y))

    return (
        dataset
        .repeat()
        .shuffle(len(train_x) * 1.33))
        .batch(128)
        .make_one_shot_iterator().get_next()
    )

def eval_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((dict(eval_x), eval_y))

    return (
        dataset
        .batch(len(eval_x)) # to use the entire eval set
        .make_one_shot_iterator().get_next()
    )

它们在tf.estimator.DNNRegressor的实例上调用,如下所示:

^{pr2}$

Tags: frominputdatalenreturntfdefeval
1条回答
网友
1楼 · 发布于 2024-04-18 20:01:53

使用不推荐使用的tf.contrib.learn.monitors.ValidationMonitor解决,如建议的in this StackOverflow answerValidationMonitor仍然可以使用monitors.replace_monitors_with_hooks实用函数在Estimator上使用。在

以下是我的实现:

from tensorflow.contrib.learn.python.learn import monitors as monitor_lib

est = tf.Estimator.DNNRegressor(...)

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
    input_fn=eval_input_fn,
    every_n_steps=100,
)
list_of_monitors_and_hooks = [validation_monitor]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, est)

est.train(
    input_fn=input_fn_train,
    steps=1000,
    hooks=hooks
)

相关问题 更多 >