使用时如何获取全局步长列车监控训练课程

2024-05-23 16:09:26 发布

您现在位置:Python中文网/ 问答频道 /正文

当我们在Saver.save中指定全局步骤时,它将把全局步骤存储为检查点后缀。在

# save the checkpoint
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)

我们可以恢复检查点并获得存储在检查点中的最后一个全局步骤,如下所示:

^{pr2}$

如果我们使用tf.train.MonitoredTrainingSession,那么将全局步骤保存到检查点并获得gstep的等效方法是什么?在

编辑1

根据Maxim的建议,我在tf.train.MonitoredTrainingSession之前创建了global_step变量,并添加了一个CheckpointSaverHook,如下所示:

global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoints_abs_path,
                                                    save_steps=5,
                                                    checkpoint_basename=(checkpoints_prefix + ".ckpt"))

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=is_chief,                     
                                       hooks=[sync_replicas_hook, save_checkpoint_hook],
                                       config=config) as session:

    _, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
    print("current global step=" + str(gstep))

我可以看到它生成的检查点文件类似于Saver.saver所做的。但是,它无法从检查点检索全局步骤。请告诉我该怎么解决这个问题?在


Tags: sessionsavetfstep步骤train全局global
1条回答
网友
1楼 · 发布于 2024-05-23 16:09:26

您可以通过^{}或通过^{}函数获得当前全局步骤。在培训开始前,应通知后者。在

对于被监视的会话,将^{}添加到hooks,后者在内部使用定义的全局步骤张量在每N个步骤之后保存模型。在

相关问题 更多 >