Tensorflow:如何保存变量并将它们加载到不同的变量?

2024-05-08 13:41:45 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有两个相同的网络,AB。我保存了(使用Saver)network A的前一个状态,现在我想把它加载到network B(所有这些都发生在同一次运行中)。我该怎么做?你知道吗


Tags: 网络状态networksaver
1条回答
网友
1楼 · 发布于 2024-05-08 13:41:45

让我举个例子。首先,让我们定义并保存一些变量:

import tensorflow as tf

v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmp.ckpt')

现在,让我们在一个新图形中定义一些具有相同名称的变量,并从检查点加载它们的值:

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros(1), name='v1')
    v2 = tf.Variable(tf.zeros(1), name='v2')

    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './tmp.ckpt')
        print(sess.run([v1, v2]))

最后一行打印:

[array([1.], dtype=float32), array([2.], dtype=float32)]

相关问题 更多 >

    热门问题