从Tensorflow 2.0导出pb文件以与C API一起使用

2024-04-20 09:49:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用Tensorflow 2.0训练了一个模型,并希望将其用于C API

我知道我的C API实现是正确的,因为其他(我下载的)PB文件似乎工作得很好

但使用这些命令生成的pb文件不起作用:

model.save("./model")
#or
tf.saved_model.save(model, "./model")

每当我尝试用带有的C API加载它时,就会得到一个“Invalid GraphDef”状态

TF_Buffer* GraphDefinition = ReadBufferFromFile("./model/saved_model.pb");
[...]
TF_GraphImportGraphDef(Graph, GraphDefinition, GraphDefOptions, Status);

我尝试了freeze_graph thing,但在TF2.0上似乎不起作用。。。在这个问题上,大量的资源已经过时了。我假设TF2.0生成的PB文件可能与TF1.x生成的文件格式不同

那么,在C/C++环境中运行此模型的选项是什么

(我不喜欢用bazel之类的东西编译东西,在我的例子中,C API DLL非常方便)


Tags: or文件模型命令apimodelsavetf
1条回答
网友
1楼 · 发布于 2024-04-20 09:49:52

我找到了解决办法!(与谷歌colab合作)

首先,我使用Tensorflow v1.x和

%tensorflow_version 1.x

可以肯定的是:

import tensorflow
print(tensorflow.__version__)
# return 1.15.2

我基本上只是复制粘贴的this code来加载我的网络的.h5版本

model.save("model.h5")

相关问题 更多 >