在导出到tflite表单之前,修复冻结图形的输入节点

2024-05-23 17:49:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我可以使用以下方法冻结图形:

freeze_graph.freeze_graph(input_graph=f"{save_graph_path}/graph.pbtxt",
                          input_saver="",
                          input_binary=False,
                          input_checkpoint=last_ckpt,
                          output_node_names="network/output_node",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph=output_frozen_graph_name,
                          clear_devices=True,
                          initializer_nodes="")

然而,该图有两个显著的输入节点,即“input/is\u training”和“input/input\u node”。你知道吗

我想将这个冻结的图形导出为tflite格式,但在这样做时,我需要将is\u training修复为False(因为它用于tf.layers.batch\u规范化). 你知道吗

我知道将is\u training占位符设置为False可以解决这个问题,但是假设我只有冻结的图形文件和检查点,我该如何进行呢?还是不可能?你知道吗


Tags: path方法namenodefalse图形inputoutput
1条回答
网友
1楼 · 发布于 2024-05-23 17:49:57

只需加载冻结的图形,将有问题的值映射到常量,然后再次保存图形,就可以实现这一点。你知道吗

import tensorflow as tf

with tf.Graph().as_default():
    # Make constant False value (name does not need to match)
    is_training = tf.constant(False, dtype=tf.bool, name="input/is_training")
    # Load frozen graph
    gd = tf.GraphDef()
    with open(f"{save_graph_path}/graph.pbtxt", "r") as f:
        gd.ParseFromString(f.read())
    # Load graph mapping placeholder to constant
    tf.import_graph_def(gd, name="", input_map={"input/is_training:0": is_training})
    # Save graph again
    tf.train.write_graph(tf.get_default_graph(), save_graph_path, "graph_modified.pbtxt",
                         as_text=True)

相关问题 更多 >