了解TensorFlow检查点加载吗?

2024-04-18 08:07:31 发布

您现在位置:Python中文网/ 问答频道 /正文

TF检查站里有什么?例如,估计器存储一个单独的文件,其中包含GraphDef原型,您基本上可以做一个tf.import_graph_def(),然后创建一个tf.train.Saver()并将一个检查点还原到图中。现在,如果您有另一个GraphDef来描述一个完全不同的图,它恰好共享完全相同的变量名和匹配的变量维,那么您是否能够将检查点加载到该图中?换言之,它只是一个变量名到值的映射,还是假设了在加载过程中要检查的图形的其他内容?如果您尝试将检查点加载到作为原始图子集的图中(即张量维度和名称匹配,但缺少一些名称),该怎么办?在


Tags: 文件import名称图形过程tfdeftrain
1条回答
网友
1楼 · 发布于 2024-04-18 08:07:31

人们什么时候开始阅读文档(?): https://www.tensorflow.org/mobile/prepare_models

这些是不同的概念。只要形状匹配,就可以加载权重。如果有一场比赛失利,你会得到:

Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.

但是,您可以调整一个非常重要的情况,其中图形完全不同:

import tensorflow as tf
import numpy as np

test_data = np.arange(4).reshape(1, 2, 2, 1)

# a simple graph and everything is fine
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
output = tf.layers.conv2d(input, 3, kernel_size=1, name='test', use_bias=False)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(output, {input: test_data}))
  saver = tf.train.Saver()
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print(tf.trainable_variables())

# reset previous elements
tf.reset_default_graph()

# a new graph
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
# and wait: this is complete different but same name and shape
W = tf.get_variable('test/kernel', shape=[1, 1, 1, 3])
# but the graph has different operations
output = input + W

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  saver.restore(sess, "/tmp/model.ckpt")
  print(sess.run(output, {input: test_data}))

就我而言,我得到了:

^{pr2}$

相关问题 更多 >