TensorFlow:将py_func保存到.pb fi

2024-04-20 08:02:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图建立一个tensorflow模型tf.py_func公司在普通python代码中创建一部分代码。问题是,当我将模型保存到.pb文件中时,.pb文件本身非常小,不包括py_函数:0张量。当我试图从.pb文件加载并运行模型时,我得到以下错误:get ValueError:callback pyfunc_0 is not found。在

当我不保存和加载为.pb文件时,它会工作

有人能帮忙吗。这对我来说非常重要,让我睡了好几个晚上。在

model_version = "465555564"
tensorboard = TensorBoard(log_dir='./logs', histogram_freq = 0, write_graph = True, write_images = False)

sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)

def my_func(x):
    some_function

input = tf.placeholder(tf.float32)
y = tf.py_func(my_func, [input], tf.float32)

prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"inputs": input}, {"prediction": y})
builder = saved_model_builder.SavedModelBuilder('./'+model_version)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
      sess, [tag_constants.SERVING],
      signature_def_map={
           signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:prediction_signature,
      },
      legacy_init_op=legacy_init_op)

builder.save()

Tags: 文件py模型inputmodelinittfdef
1条回答
网友
1楼 · 发布于 2024-04-20 08:02:34

有一种使用tf.py_func保存TF模型的方法,但是您必须使用SavedModel来保存TF模型。在

TF有两个级别的模型保存:检查点和SavedModels。有关更多详细信息,请参见this answer,但要在此处引用:

  • A checkpoint contains the value of (some of the) variables in a TensorFlow model. It is created by a Saver. To use a checkpoint, you need to have a compatible TensorFlow Graph, whose Variables have the same names as the Variables in the checkpoint.
  • SavedModel is much more comprehensive: It contains a set of Graphs (MetaGraphs, in fact, saving collections and such), as well as a checkpoint which is supposed to be compatible with these Graphs, and any asset files that are needed to run the model (e.g. Vocabulary files). For each MetaGraph it contains, it also stores a set of signatures. Signatures define (named) input and output tensors.

tf.py_funcop不能用SavedModel(在this page in the docs上注明)保存,这正是您试图在这里做的。这是有原因的。SavedModel应该完全独立于原始代码,可以用任何其他可以反序列化的语言加载。这允许模型加载诸如ML Engine之类的东西,这可能是用C++编写的或类似的东西。问题是它不能序列化任意Python代码,因此py_func是不可能的。在

您可以通过使用检查点来解决这个问题,只要您还可以继续使用Python。您将无法获得SavedModels所提供的独立性。您可以在使用tf.train.Saver训练后保存一个检查点,然后在一个新的Session中,重新构建整个图并用Saver加载它。甚至还有一种方法可以在ML引擎中使用该代码,它以前是专用于SavedModels的。在

有关在the docs中保存/恢复模型的详细信息。在

相关问题 更多 >