<p>Based on keveman's answer, I wrote a Python script that renames the variables in any TensorFlow checkpoint:</p>
<p><a href="https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96" rel="noreferrer">https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96</a></p>
<p>It can replace a substring in the variable names and add a prefix to every name. Call the script with</p>
<pre><code>python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir
</code></pre>
<p>and the optional arguments</p>
<pre><code>--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run
</code></pre>
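<p>The gist's own argument handling is not shown here. As a hypothetical sketch, assuming the standard-library <code>argparse</code> module, the flags above could be declared like this (the <code>build_parser</code> name is illustrative and not taken from the script):</p>
<pre><code>import argparse


def build_parser():
    # Hypothetical parser mirroring the documented flags; the real script may differ.
    parser = argparse.ArgumentParser(
        description="Rename variables in a TensorFlow checkpoint.")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="Directory containing the checkpoint.")
    parser.add_argument("--replace_from", default=None,
                        help="Substring to replace in each variable name.")
    parser.add_argument("--replace_to", default=None,
                        help="Replacement for --replace_from.")
    parser.add_argument("--add_prefix", default=None,
                        help="Prefix prepended to every variable name.")
    parser.add_argument("--dry_run", action="store_true",
                        help="Only print the planned renames, do not save.")
    return parser


# Parse an explicit argument list instead of sys.argv for demonstration.
args = build_parser().parse_args(
    ["--checkpoint_dir=path/to/dir", "--replace_from=scope1",
     "--replace_to=scope1/model", "--add_prefix=abc/", "--dry_run"])
print(args.checkpoint_dir, args.add_prefix, args.dry_run)
</code></pre>
<p>Using <code>action="store_true"</code> makes <code>--dry_run</code> a plain switch, and leaving the other optional flags at <code>None</code> matches the <code>None</code> checks in the core function.</p>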
<p>Here is the core of the script:</p>
<pre><code>import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and tf.Session)


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run=False):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable's value from the checkpoint
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Build the new name
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Recreate the variable in the graph under its new name
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save all variables back to the latest checkpoint path (overwriting it)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)
</code></pre>
<p>Example:</p>
<pre><code>python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir --replace_from=scope1 --replace_to=scope1/model --add_prefix=abc/
</code></pre>
<p>renames the variable <code>scope1/Variable1</code> to <code>abc/scope1/model/Variable1</code>.</p>
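<p>The naming rule from the loop above can be checked without TensorFlow. This is a minimal sketch that isolates it in a hypothetical helper, <code>new_variable_name</code>, which is not part of the original script:</p>
<pre><code>def new_variable_name(var_name, replace_from=None, replace_to=None, add_prefix=None):
    # Hypothetical helper reproducing the rename rule of the core loop.
    new_name = var_name
    # Substring replacement only happens when both values are given.
    if replace_from is not None and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    # The prefix is applied after the replacement.
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


print(new_variable_name("scope1/Variable1", "scope1", "scope1/model", "abc/"))
# prints abc/scope1/model/Variable1
</code></pre>
<p>Because <code>str.replace</code> replaces every occurrence, a substring that appears more than once in a name is replaced each time: <code>scope1/scope1_b/w</code> would become <code>scope1/model/scope1/model_b/w</code>. Choose a <code>--replace_from</code> value specific enough to avoid unintended matches, and preview the result with <code>--dry_run</code>.</p>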