如何使用SessionRunHook打印张量tf.data.Dataset应用程序编程接口?

2024-04-26 02:29:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用tf.data.Dataset并为传递给Dataset.map的闭包中的操作分配名称,如下所示

import tensorflow as tf


def model_fn(features, mode):
    loss = tf.constant(1)
    train_op = tf.no_op()
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


def input_fn():
    dataset = tf.data.Dataset \
        .from_generator(lambda: (x*x for x in range(10)), tf.int32) \
        .map(lambda x: tf.identity(x, name='tokens_inside'))

    ret = dataset.make_one_shot_iterator().get_next()
    tf.identity(ret, 'tokens_outside')

    return ret


tf.logging.set_verbosity(tf.logging.INFO)

hooks = [
    tf.train.LoggingTensorHook(['tokens_outside'], every_n_iter=1),
    tf.train.LoggingTensorHook(['tokens_inside'], every_n_iter=1),
]

est = tf.estimator.Estimator(model_fn=model_fn, model_dir='mout')
est.train(input_fn=input_fn, hooks=hooks, max_steps=1)

当使用tf.train.LoggingTensorHook转储一些值时,第二个钩子抛出一个异常:

我遇到这样的错误:

^{pr2}$

我想Dataset操作会为每个函数创建一个新的图形?有没有一种方法可以定制tf.train.LoggingTensorHook以便它知道要搜索哪个图来命名张量?在


Tags: mapinputdatamodelmodetftraindataset