tf到onnx转换后所需的输入输出张量名称

2024-05-23 14:24:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在编写一个软件,它从模型提供者那里接收一个ONNX模型。最终目标是将此模型转换为Tensorflow,这就是它在运行时的使用方式。现有的TF infra和其他深度学习模型都在TF中

在后端,使用ONNX_TF工具将导入的ONNX模型转换并导出为TF graph pb文件。然后我们使用import_graph_def()将模型加载到GraphDef。这样,我们将能够在模型加载到图中时获得默认图

但是,我当前的问题是,从Tensorflow GraphDef对象来看,不清楚图形的输入和输出节点是什么。我看到这样的解决方案:

input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

然而,在将tf-pf文件加载到GraphDef时,我所拥有的只是pb文件,而不是张量名称。我不知道这个问题怎么解决。在将TF或PyTorch模型转换为ONNX时,我是否必须要求模型提供者标记其输入和输出节点?或者有一种方法可以从ONNX图形中检索信息?理想情况下,我们不希望在保存格式、版本等方面限制用户编写原始深度学习模型的方式(在onnx转换之前)

系统信息:

Tensorflow Version: 1.15
Python version: 3.6
onnx==1.7.0

Tags: 文件模型图形getby节点tftensorflow