受监控的培训课程如何运作?

2024-06-10 15:17:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图理解使用tf.Sessiontf.train.MonitoredTrainingSession之间的区别,以及我可能更喜欢其中一个而不是另一个。当我使用后者时,似乎可以避免许多“杂务”,如初始化变量、启动队列运行器或为摘要操作设置文件编写器。另一方面,对于监视的训练会话,我无法指定要显式使用的计算图。这一切在我看来都很神秘。这些课程是如何创建的,背后是否有一些我不理解的哲学?


Tags: 文件队列sessiontftrain课程区别哲学
1条回答
网友
1楼 · 发布于 2024-06-10 15:17:41

我无法对这些类是如何创建的给出一些见解,但我认为以下几点与如何使用它们有关。

tf.Session是python TensorFlow API中的低级对象,而, 正如您所说,tf.train.MonitoredTrainingSession具有许多方便的特性,特别是在大多数常见情况下非常有用。

在描述tf.train.MonitoredTrainingSession的一些好处之前,让我回答关于会话使用的图的问题。您可以使用上下文管理器指定MonitoredTrainingSession使用的tf.Graph

from __future__ import print_function
import tensorflow as tf

def example():
    g1 = tf.Graph()
    with g1.as_default():
        # Define operations and tensors in `g`.
        c1 = tf.constant(42)
        assert c1.graph is g1

    g2 = tf.Graph()
    with g2.as_default():
        # Define operations and tensors in `g`.
        c2 = tf.constant(3.14)
        assert c2.graph is g2

    # MonitoredTrainingSession example
    with g1.as_default():
        with tf.train.MonitoredTrainingSession() as sess:
            print(c1.eval(session=sess))
            # Next line raises
            # ValueError: Cannot use the given session to evaluate tensor:
            # the tensor's graph is different from the session's graph.
            try:
                print(c2.eval(session=sess))
            except ValueError as e:
                print(e)

    # Session example
    with tf.Session(graph=g2) as sess:
        print(c2.eval(session=sess))
        # Next line raises
        # ValueError: Cannot use the given session to evaluate tensor:
        # the tensor's graph is different from the session's graph.
        try:
            print(c1.eval(session=sess))
        except ValueError as e:
            print(e)

if __name__ == '__main__':
    example()

所以,正如您所说,使用MonitoredTrainingSession的好处是,这个对象负责

  • 初始化变量
  • 同时启动队列管理器
  • 设置文件写入程序

但是它也有使代码易于分发的好处,因为它的工作方式也不同,这取决于您是否将正在运行的进程指定为主进程。

例如,您可以运行类似于:

def run_my_model(train_op, session_args):
    with tf.train.MonitoredTrainingSession(**session_args) as sess:
        sess.run(train_op)

以非分布式方式调用:

run_my_model(train_op, {})`

或以分布式方式(有关输入的更多信息,请参见distributed doc):

run_my_model(train_op, {"master": server.target,
                        "is_chief": (FLAGS.task_index == 0)})

另一方面,使用原始tf.Session对象的好处是,您没有tf.train.MonitoredTrainingSession的额外好处,如果您不打算使用它们,或者希望获得更多的控制权(例如队列是如何启动的),这些好处可能很有用。

编辑(根据注释): 对于操作初始化,您必须执行如下操作(cf.official doc

# Define your graph and your ops
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_p)
    sess.run(your_graph_ops,...)

对于QueueRunner,我将向您介绍official doc,在这里您将找到更完整的示例。

编辑2:

了解tf.train.MonitoredTrainingSession如何工作的主要概念是_WrappedSession类:

This wrapper is used as a base class for various session wrappers that provide additional functionality such as monitoring, coordination, and recovery.

tf.train.MonitoredTrainingSession以这种方式工作(截至version 1.1):

  • 它首先检查它是一个主管还是一个工人(参见词法问题的distributed doc)。
  • 它开始提供的钩子(例如,StopAtStepHook在这个阶段将只检索global_step张量。
  • 它创建一个会话,它是一个Chief(或Worker会话)包装成一个_HookedSession包装成一个_CoordinatedSession包装成一个_RecoverableSession
    Chief/Worker会话负责运行由Scaffold提供的初始化操作。
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
    not specified a default one is created. It's used to finalize the graph.
    
  • chief会话还负责所有检查点部分:例如,使用Scaffold中的Saver从检查点还原。
  • _HookedSession基本上是用来修饰run方法的:它在相关的时候调用_call_hook_before_runafter_run方法。
  • 在创建时,_CoordinatedSession构建一个Coordinator来启动队列运行器并负责关闭它们。
  • _RecoverableSession将确保在发生tf.errors.AbortedError时重试。

总之,tf.train.MonitoredTrainingSession避免了大量的锅炉板代码,同时使用hooks机制易于扩展。

相关问题 更多 >