如何获取Keras模型中tensorflow输出节点的名称?

2024-05-15 10:10:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试图从我的Keras(tensorflow backend)模型创建一个pb文件,以便在iOS上构建它。我正在使用freeze.py,需要传递输出节点。如何获取Keras模型的输出节点的名称?

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py


Tags: 文件pyhttps模型github名称combackend
3条回答

可以使用Keras model.summary()获取最后一层的名称。

如果model.outputs不为空,则可以通过以下方式获取节点名称:

[node.op.name for node in model.outputs]

你可以通过

session = keras.backend.get_session()

然后通过

min_graph = convert_variables_to_constants(session, session.graph_def, [node.op.name for node in model.outputs])

之后,您可以通过

tensorflow.train.write_graph(min_graph, "/logdir/", "file.pb", as_text=True)

如果在Keras中构造模型时未显式指定输出节点,则可以按如下方式打印它们:

[print(n.name) for n in tf.get_default_graph().as_graph_def().node]

然后你需要做的就是找到正确的一个,它通常类似于激活函数的名称。您可以使用这个字符串名作为freeze_graph函数中output_node_names的值。

您还可以使用tensorflow实用程序:summarize_graph查找可能的output_nodes。从official documentation

Many of the transforms that the tool supports need to know what the input and output layers of the model are. The best source for these is the model training process, where for a classifier the inputs will be the nodes that receive the data from the training set, and the output will be the predictions. If you're unsure, the summarize_graph tool can inspect the model and provide guesses about likely input and output nodes, as well as other information that's useful for debugging.

它只需要保存的图形pb文件作为输入。查看文档中的示例。

相关问题 更多 >