将张量的数据类型(从现有模型)从uint8更改为float

2024-04-26 04:46:18 发布

您现在位置:Python中文网/ 问答频道 /正文

您好,我想将现有模型的输入数据类型从uint8更改为float,我加载了模型,但我不知道如何修改输入数据类型,请帮助我

def load( checkpoint_filename, input_name="images",
             output_name="features"):
    session = tf.compat.v1.Session()
    with tf.io.gfile.GFile(checkpoint_filename, "rb") as file_handle:
        graph_def =  tf.compat.v1.GraphDef()
        graph_def.ParseFromString(file_handle.read())
    tf.import_graph_def(graph_def, name="net")
    input_var = tf.compat.v1.get_default_graph().get_tensor_by_name(
        "net/%s:0" % input_name)
    print(input_var.dtype.as_numpy_dtype)
load('./converttf2tflite/mars-small128TOtrt.pb')

Tags: name模型inputtfdefasloadfilename