列车保护器在不同的计算机上加载最新的检查点

2024-04-24 18:41:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个经过训练的模型,它是用tf.train.Saver保存的,生成4个相关文件

  • checkpoint
  • model_iter-315000.data-00000-of-00001
  • model_iter-315000.index
  • model_iter-315000.meta

既然它是通过docker容器生成的,那么机器本身和docker上的路径是不同的,就像我们在两台不同的机器上工作一样。在

我正在尝试将保存的模型加载到容器外部。在

当我运行以下程序时

sess = tf.Session()
saver = tf.train.import_meta_graph('path_to_.meta_file_on_new_machine')  # Works
saver.restore(sess, tf.train.latest_checkpoint('path_to_ckpt_dir_on_new_machine')  # Fails

错误是

tensorflow.python.framework.errors_impl.NotFoundError: PATH_ON_OLD_MACHINE; No such file or directory

即使我在调用tf.train.latest_checkpoint时提供了新路径,但仍会得到错误,显示旧路径上的路径。在

我怎么解决这个问题?在


Tags: topathdocker模型路径机器modeltf
3条回答

这里的方法不需要编辑检查点文件或手动查看检查点目录。如果我们知道检查点前缀的名称,我们可以使用regex并假设tensorflow在checkpoint文件的第一行写入最新的检查点:

import tensorflow as tf
import os
import re


def latest_checkpoint(ckpt_dir, ckpt_prefix="model.ckpt", return_relative=True):
    if return_relative:
        with open(os.path.join(ckpt_dir, "checkpoint")) as f:
            text = f.readline()
        pattern = re.compile(re.escape(ckpt_prefix + "-") + r"[0-9]+")
        basename = pattern.findall(text)[0]
        return os.path.join(ckpt_dir, basename)
    else:
        return tf.train.latest_checkpoint(ckpt_dir)

“checkpoint”文件是一个索引文件,它本身就嵌入了路径。在文本编辑器中打开它并将路径更改为正确的新路径。在

或者,使用^{}加载特定的检查点,而不是依赖TensorFlow为您找到最新的检查点。在这种情况下,它不会引用“checkpoint”文件,不同的路径也不会成为问题。在

或者编写一个小脚本来修改“checkpoint”的内容。在

如果打开checkpoint文件,您将看到如下内容:

model_checkpoint_path: "/PATH/ON/OLD/MACHINE/model.ckpt-315000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-300000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-285000"
[...]

只要删除/PATH/ON/OLD/MACHINE/,或者用/PATH/ON/NEW/MACHINE/替换它,就可以开始了。在

编辑: 将来,在创建tf.train.Saver时,应该使用save_relative_paths选项。引用doc

save_relative_paths: If True, will write relative paths to the checkpoint state file. This is needed if the user wants to copy the checkpoint directory and reload from the copied directory.

相关问题 更多 >