是否可以从冻结图中删除批次维度?

2024-04-19 19:01:00 发布

您现在位置:Python中文网/ 问答频道 /正文

检查冻结张量流模型:

wget https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz

我看到输入大小是Tensor 'input:0', which has shape '(1, 299, 299, 3)',我想知道是否有可能使输入(None, 299, 299, 3)以使批大小大于1的批预测可用?你知道吗


Tags: httpsorg模型commodelsdownloadtensorflowstorage
1条回答
网友
1楼 · 发布于 2024-04-19 19:01:00

在一般情况下,可能不可能这样做,因为可能存在依赖于第一维度为1的操作(例如,假设^{}用于input:0)。但是,可以尝试用所需形状的占位符替换输入。你可以用^{}来做这个。如果操作允许,那么TensorFlow应该导入相应调整节点形状的图。请参见以下示例:

import tensorflow as tf

# First graph
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [1, 10, 20], name='Input')
    y = tf.square(x, name='Output')
    print(y)
    # Tensor("Output:0", shape=(1, 10, 20), dtype=float32)
    gd = tf.get_default_graph().as_graph_def()

# Second graph
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 10, 20], name='Input')
    y, = tf.graph_util.import_graph_def(gd, input_map={'Input:0': x},
                                        return_elements=['Output:0'], name='')
    print(y)
    # Tensor("Output:0", shape=(?, 10, 20), dtype=float32)

在第一个图中,Output:0节点有一个(1, 10, 20)形状,这是从Input:0张量的形状推断出来的。但是,当我从第一个图中获取图定义并加载到第二个图中时,用一个第一维度未定义的占位符替换Input:0张量,Output:0的形状被更新为(?, 10, 20)。如果我在第二个图中运行操作,给出第一个维度大于1的输入值,它将按预期工作,因为该图是正确的。你知道吗

相关问题 更多 >