如何使用python知道tflite模型中的Conv2D参数?

2024-06-01 05:59:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我能够通过下面的代码找到tflite模型中每个Conv2D的输入/输出张量形状

import tensorflow as tf

SAVED_MODEL_PATH = "TFLITEMODEL_PATH.tflite"
interpreter = tf.lite.Interpreter(model_path=SAVED_MODEL_PATH)

ops = interpreter._get_ops_details()
for op_index, op in enumerate(ops):
    if op['op_name'] == "CONV_2D":
        cnt += 1
        for tensor_idx in op['inputs']:
            tensor = interpreter2._get_tensor_details(tensor_idx)
            tensor_shape = tensor['shape']
            print(tensor['name'], "\t", tensor['shape'])
        print("----")

下面是输出

Placeholder      [  1 224 224   3]
conv2d/kernel    [64  7  7  3]
conv2d/Conv2D_bias   [64]
----
block-0/denseblock-0-0/Relu      [ 1 56 56 64]
block-0/denseblock-0-0/conv2d/kernel     [32  3  3 64]
block-0/denseblock-0-0/conv2d/Conv2D_bias    [32]
----
block-0/denseblock-0-1/Relu      [ 1 56 56 96]
block-0/denseblock-0-1/conv2d/kernel     [32  3  3 96]
block-0/denseblock-0-1/conv2d/Conv2D_bias    [32]
----

但我想知道如何用python代码知道它的Conv2D参数(如填充、跨步、膨胀等)。我想要像netron.app这样的信息。它显示所有层及其信息,如名称、填充、步幅等。 enter image description here


Tags: path代码tfblockkernelopstensorshape
1条回答
网友
1楼 · 发布于 2024-06-01 05:59:25

没有官方的方法可以做到这一点_get_ops_details不是公共API,也不能保证稳定

我可以知道你想达到什么目的吗

从技术上讲,您可以深入细节,自己解析TFLite FlatBuffer模型。然而,这也不是一条正式的道路

相关问题 更多 >