如何将张量类型的值传递给变量的shape参数?

2024-04-19 21:35:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我遇到了一个问题,可以总结如下:

foo = tf.constant(3)
foo_variable = tf.get_variable("foo", shape=[foo], dtype=tf.int32)

变量的形状必须依赖于张量的值(foo这里只是从其他运算中提取计算结果)

这里的错误是The shape of a variable can not be a Tensor object

如何解决?你知道吗


Tags: ofthegetfootf错误notvariable
2条回答

创建一个由foo张量指定形状的张量初始值设定项,然后使用此初始值设定项用validate_shape=False实例化一个新变量:

import tensorflow as tf

x = tf.placeholder(tf.int32, shape=())
shape = tf.constant([2, 3]) + x
init = tf.zeros(shape, dtype=tf.int32)
v = tf.get_variable('foo', initializer=init, validate_shape=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer(), {x: 1})
    print(v.eval())
    # [[0 0 0 0]
    #  [0 0 0 0]
    #  [0 0 0 0]]

tensorflow变量不能具有动态形状,但如果您知道会话外的形状,则可以使用:

foo_variable = tf.get_variable("foo", shape=[], validate_shape=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(foo_variable, feed_dict={foo_variable: ones((2,2))}))

相关问题 更多 >