这可能非常简单,但我找不到答案。我试着在tf.while_循环. 为了简单起见,我只传递了(3,4)形张量‘x’,暂时,在‘body’函数中什么也不做。但这一论点的通过似乎引发了一些问题。堆栈跟踪只告诉“AssertionError:”。请帮忙。 代码:
import tensorflow as tf
import numpy as np
def cond(sequence_len, step, x):
return tf.less(step,sequence_len)
def body(sequence_len, step, x):
return (sequence_len, step+1)
step = tf.constant(0)
sequence_len = tf.constant(10)
x = tf.zeros([3, 4], tf.int32)
res,step = tf.while_loop(cond,body,[sequence_len, step, x])
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
step_eval = step.eval(session=sess)
print(step_eval)
完整的堆栈跟踪也粘贴在下面。 The image of the stack trace
tf.while_loop()您需要确保body()是一个可调用的函数,它获取一个张量列表,并返回一个长度相同的张量列表,并以相同的类型作为输入。这就是While循环的工作原理。每个返回都作为输入参数发回。也就是说,上一次返回是下一次迭代的输入参数。在
相关问题 更多 >
编程相关推荐