在十位数内找到所有检查点的路径

2024-04-19 09:32:59 发布

您现在位置:Python中文网/ 问答频道 /正文

到目前为止,我只在Tensorflow中使用保存和加载检查点来加载最后一个检查点。通常我使用的代码如下:

ckpt = tf.train.get_checkpoint_state(load_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(session, ckpt.model_checkpoint_path)
else:
    tf.gfile.DeleteRecursively(load_dir)
    tf.gfile.MakeDirs(load_dir)

然而,在我最近的实验中,我每1000次迭代就保存一个检查点,我想对所有检查点运行一个评估脚本,例如,展示不同的验证度量是如何进行的。在Tensorflow中有没有任何简单的方法来获取所有检查点,或者我只需要使用os来遍历所有的名称?在


Tags: path代码getmodeliftftensorflowdir
1条回答
网友
1楼 · 发布于 2024-04-19 09:32:59

代码段中的ckpt对象是CheckpointState协议缓冲区。您不必访问最新的模型路径(ckpt.model_checkpoint_path),而是可以使用以下方法迭代所有这些路径:

for model_path in ckpt.all_model_checkpoint_paths:
    saver.restore(session, model_path)
    # Do the evaluation using the restored model

相关问题 更多 >