副作用tf.while_循环

2024-04-19 16:45:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我现在很难理解tensorflow是如何工作的,我觉得python接口有些模糊。在

我最近试着在里面打印一份简单的声明tf.while_循环,还有很多事情我还不清楚:

import tensorflow as tf

nb_iter = tf.constant(value=10)
#This solution does not work at all
#nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)
i = tf.get_variable('i', shape=(), trainable=False,
                     initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    tf.Print(i, [i], message='Another iteration')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))

注意,如果我用

^{pr2}$

我得到了以下错误:

ValueError: Shape must be rank 0 but is rank 1 for 'while/LoopCond' (op: 'LoopCond') with input shapes: [1].

当我尝试使用'I'索引索引一个张量时,情况会变得更糟(这里没有显示示例),然后我得到以下错误

alueError: Operation 'while/strided_slice' has been marked as not fetchable.

有人能给我指一份文件来解释tf.while_循环与一起使用时有效tf.变量,以及是否可以在循环内部使用副作用(如打印),以及用循环变量索引张量?在

提前感谢您的帮助


Tags: loopgettftensorflowasnotresvariable
1条回答
网友
1楼 · 发布于 2024-04-19 16:45:12

我的第一个例子实际上有很多问题:

在tf.打印如果运算符没有副作用(即i=tf.打印())

如果布尔值是标量,那么它就是秩0张量,而不是秩1张量。。。在

以下是有效的代码:

import tensorflow as tf

#nb_iter = tf.constant(value=10)
#This solution does not work at all
nb_iter = tf.get_variable('nb_iter', shape=(), dtype=tf.int32, trainable=False,
                          initializer=tf.zeros_initializer())
nb_iter = tf.add(nb_iter,10)
i = tf.get_variable('i', shape=(), trainable=False,
                     initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)
v = tf.get_variable('v', shape=(10), trainable=False,
                     initializer=tf.random_uniform_initializer, dtype=tf.float32)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    i = tf.Print(i, [v[i]], message='Another vector element: ')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))

输出:

^{pr2}$

相关问题 更多 >