Tensorflow:如何替换计算图中的节点?

2024-05-23 09:42:59 发布

您现在位置:Python中文网/ 问答频道 /正文

如果您有两个不相交的图,并且要将它们链接起来,请转动:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

进入这个:

x = tf.placeholder('float')
y = f(x)
z = g(y)

有办法吗?在某些情况下,它似乎可以使施工变得更容易。

例如,如果您有一个图形,它的输入图像是一个tf.placeholder,并且想要优化输入图像,deep dream样式,那么有没有办法用一个tf.variable节点替换占位符呢?或者在建立图表之前,你必须考虑这个问题吗?


Tags: 图像图形节点链接tf图表情况样式
3条回答

TL;DR:如果您可以将这两个计算定义为Python函数,那么您应该这样做。如果不能,则在TensorFlow中有更高级的功能来序列化和导入图,这允许您从不同的源合成图。

在TensorFlow中,一种方法是将不相交的计算构建为单独的tf.Graph对象,然后使用^{}将它们转换为序列化的协议缓冲区:

with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1       
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  output = tf.identity(y, name="output")

gdef_2 = g_2.as_graph_def()

然后您可以使用^{}gdef_1gdef_2组合成第三个图:

with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32, name="")

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"]

结果是tf.train.import_meta_graph将所有附加参数传递给具有input_map参数的底层import_scoped_meta_graph,并在它自己(内部)调用import_graph_def时使用它。

它没有记录,我花了很多时间才找到它,但它的工作!

如果要组合经过训练的模型(例如在新模型中重用预训练模型的部分),可以使用Saver保存第一个模型的检查点,然后将该模型(全部或部分)还原到另一个模型中。

例如,假设要重用模型2中模型1的权重w,并将x从占位符转换为变量:

with tf.Graph().as_default() as g1:
    x = tf.placeholder('float')
    w = tf.Variable(1., name="w")
    y = x * w
    saver = tf.train.Saver()

with tf.Session(graph=g1) as sess:
    w.initializer.run()
    # train...
    saver.save(sess, "my_model1.ckpt")

with tf.Graph().as_default() as g2:
    x = tf.Variable(2., name="v")
    w = tf.Variable(0., name="w")
    z = x + w
    restorer = tf.train.Saver([w]) # only restore w

with tf.Session(graph=g2) as sess:
    x.initializer.run()  # x now needs to be initialized
    restorer.restore(sess, "my_model1.ckpt") # restores w=1
    print(z.eval())  # prints 3.

相关问题 更多 >

    热门问题