在同一Tensorflow会话中从Saver加载两个模型

2024-04-25 08:35:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个网络:生成输出的Model和对输出进行分级的Adversary

两者都是单独训练的,但现在我需要在一次训练中合并它们的输出。

我试图实现本文中提出的解决方案:Run multiple pre-trained Tensorflow nets at the same time

我的代码

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`
    # Also strip out the surperfluous ":0" for some reason not saved in the checkpoint
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

问题

对函数model_saver.restore()的调用似乎什么也没做。在另一个模块中,我使用一个带有tf.train.Saver(tf.global_variables())的保存程序,它可以很好地恢复检查点。

模型有model.tvars = tf.trainable_variables()。为了检查发生了什么,我使用sess.run()来提取恢复前后的tvars。每次使用初始随机分配的变量而不分配检查点中的变量时。

有没有想过为什么model_saver.restore()看起来无所事事?


Tags: thenamemodeltfargstrainrestorevariables
3条回答

请检查:

adv_varlist = {v.name.lstrip("avd/")[:-2]: v 

应该是“adv”,不是吗

标记为正确的答案并没有告诉我们如何将两个不同的模型显式加载到一个会话中,下面是我的答案:

  1. 为要加载的模型创建两个不同的名称作用域。

  2. 初始化两个保存程序,它们将加载两个不同网络中变量的参数。

  3. 从相应的检查点文件加载。

with tf.Session() as sess:
    with tf.name_scope("net1"):
      net1 = Net1()
    with tf.name_scope("net2"):
      net2 = Net2()

    net1_varlist = {v.op.name.lstrip("net1/"): v
                    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)

    net2_varlist = {v.op.name.lstrip("net2/"): v
                    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)

    net1_saver.restore(sess, "net1.ckpt")
    net2_saver.restore(sess, "net2.ckpt")

解决这个问题花了很长时间,所以我发布了我可能不完美的解决方案,以防其他人需要它。

为了诊断这个问题,我手动循环遍历每个变量,并逐个分配它们。然后我注意到,分配变量后,名称会改变。这里是这样描述的:TensorFlow checkpoint save and read

根据那篇文章中的建议,我用自己的图表运行了每个模型。这也意味着我必须在它自己的会话中运行每个图。这意味着以不同的方式处理会话管理。

首先我创建了两个图

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

然后是两会

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

然后我初始化每个会话中的变量,并分别还原每个图

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

从这里开始,每当需要每个会话时,我都会用with sess.as_default():包装该会话中的任何tf函数。最后,我手动关闭会话

sess.close()
adv_sess.close()

相关问题 更多 >