Converting a TF2 Object Detection API model to a frozen graph

Posted 2024-06-10 09:44:54

I trained the ssd_resnet50_v1_fpn_640x640_coco17_tpu-8 model with the TensorFlow Object Detection API, using https://github.com/tensorflow/models/blob/master/research/object_detection/model_main_tf2.py

I then exported it to a SavedModel with https://github.com/tensorflow/models/blob/master/research/object_detection/exporter_main_v2.py:

.\exporter_main_v2.py --input_type image_tensor --pipeline_config_path .\models\my_ssd_resnet50_v1_fpn\pipeline.config --trained_checkpoint_dir .\models\my_ssd_resnet50_v1_fpn\ --output_directory .\exported-models\models\Bel_model

At this stage, inference with TensorFlow works well, both from the SavedModel and from the checkpoint. I tested inference with this code: https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip

After that, I tried to convert the SavedModel to a frozen graph so that I could use it in OpenCV, following the approach from https://github.com/opencv/opencv/issues/16879#issuecomment-603815872

import tensorflow as tf
print(tf.__version__)

from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

loaded = tf.saved_model.load('models/mnist_test')
infer = loaded.signatures['serving_default']

f = tf.function(infer).get_concrete_function(flatten_input=tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32))
f2 = convert_variables_to_constants_v2(f)
graph_def = f2.graph.as_graph_def()

# Export frozen graph
with tf.io.gfile.GFile('frozen_graph.pb', 'wb') as out_file:
    out_file.write(graph_def.SerializeToString())

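Unfortunately, when I run this against my own model (calling get_concrete_function with input_1=tf.TensorSpec(shape=[None, 640, 640, 3], dtype=tf.float32), as the traceback shows), I get the following error: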

Traceback (most recent call last):
  File ".\frozen_graph.py", line 8, in <module>
    f = tf.function(infer).get_concrete_function(input_1=tf.TensorSpec(shape=[None, 640, 640, 3], dtype=tf.float32))
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1669 __call__  *
        return self._call_impl(args, kwargs)
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1685 _call_impl  **
        raise structured_err
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1678 _call_impl
        return self._call_with_structured_signature(args, kwargs,
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1756 _call_with_structured_signature
        self._structured_signature_check_missing_args(args, kwargs)
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1775 _structured_signature_check_missing_args
        raise TypeError("{} missing required arguments: {}".format(

    TypeError: signature_wrapper(*, input_tensor) missing required arguments: input_tensor
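The last line of the traceback can be read as a function signature: `signature_wrapper(*, input_tensor)` takes one required keyword-only argument named `input_tensor`, so a call that passes `input_1=...` does not supply it. The plain-Python sketch below only illustrates that reading. The `signature_wrapper` defined here is a hypothetical stand-in, not TensorFlow's implementation, and the exact error text Python produces differs from TensorFlow's.

```python
import inspect


# Hypothetical stand-in that mimics the signature printed in the error message.
def signature_wrapper(*, input_tensor):
    return {"detections": input_tensor}


# Collect the keyword-only parameters that have no default value.
required = [
    name
    for name, param in inspect.signature(signature_wrapper).parameters.items()
    if param.kind is inspect.Parameter.KEYWORD_ONLY
    and param.default is inspect.Parameter.empty
]
print(required)

# Calling with a different keyword name fails with a TypeError.
try:
    signature_wrapper(input_1="image")
    wrong_name_failed = False
except TypeError:
    wrong_name_failed = True

# Calling with the expected keyword name succeeds.
result = signature_wrapper(input_tensor="image")
```

By the same reading, the keyword passed to get_concrete_function probably needs to be `input_tensor` rather than `input_1`.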

Please help me solve this problem. Perhaps you can suggest another way to create a frozen graph. Would it be simpler to train the model with Keras?
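One possible direction, sketched under assumptions: the error message names `input_tensor` as the required keyword, and a loaded signature is already a concrete function, so it may be possible to inspect its expected inputs and pass it to `convert_variables_to_constants_v2` directly. To stay self-contained, the sketch builds a tiny stand-in model. `TinyDetector`, its `detection_scores` output, and the `uint8` input of shape `[1, None, None, 3]` are assumptions for illustration, not details confirmed by the post. Whether OpenCV can then load a frozen graph of the real detection model is not something this sketch verifies.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class TinyDetector(tf.Module):
    """Hypothetical stand-in for an exported detection model."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    # Assumed input spec: one uint8 image batch named "input_tensor".
    @tf.function(input_signature=[
        tf.TensorSpec([1, None, None, 3], tf.uint8, name="input_tensor")
    ])
    def __call__(self, input_tensor):
        pixels = tf.cast(input_tensor, tf.float32)
        return {"detection_scores": tf.reduce_mean(pixels) * self.scale}


# Export the stand-in so there is a SavedModel to load.
export_dir = os.path.join(tempfile.mkdtemp(), "saved_model")
module = TinyDetector()
tf.saved_model.save(module, export_dir, signatures=module.__call__)

loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]

# The keyword names and TensorSpecs the signature expects.
input_specs = infer.structured_input_signature[1]
print(input_specs)

# The signature is already a concrete function, so convert it directly.
frozen = convert_variables_to_constants_v2(infer)
graph_def = frozen.graph.as_graph_def()

# Write the frozen graph to disk.
frozen_path = os.path.join(tempfile.mkdtemp(), "frozen_graph.pb")
with tf.io.gfile.GFile(frozen_path, "wb") as out_file:
    out_file.write(graph_def.SerializeToString())
```

For the real model, `export_dir` would point at the exported `saved_model` directory, and the printed `input_specs` would show the name, shape and dtype to use instead of guessing them.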

