您好,我想将现有模型的输入数据类型从uint8更改为float,我加载了模型,但我不知道如何修改输入数据类型,请帮助我
def load( checkpoint_filename, input_name="images",
output_name="features"):
session = tf.compat.v1.Session()
with tf.io.gfile.GFile(checkpoint_filename, "rb") as file_handle:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(file_handle.read())
tf.import_graph_def(graph_def, name="net")
input_var = tf.compat.v1.get_default_graph().get_tensor_by_name(
"net/%s:0" % input_name)
print(input_var.dtype.as_numpy_dtype)
load('./converttf2tflite/mars-small128TOtrt.pb')
目前没有回答
相关问题 更多 >
编程相关推荐