TensorArray堆栈操作引发值

2024-04-19 08:36:04 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在while_loop中使用TensorArray,其中循环的每次迭代都填充TensorArray中的一个项。下面是一个最小的示例:

ta = tensor_array_ops.TensorArray(size=4, tensor_array_name='output_ta', dtype=tf.float32)
time = tf.constant(0)

def _call(time, ta):
    ta.write(time, tf.constant([1.,2.,3.,4.]))
    return (time+1, ta)

_, t_out = tf.while_loop(
    cond=lambda time, _: time < 4,
    body=_call,
    loop_vars=(time, ta)
)

现在,这段代码运行良好。但是,一旦我尝试使用t_out执行任何操作,就会出现一个错误,例如

^{pr2}$

有人能看出我的代码有什么问题吗?在

编辑:这似乎只发生在急切模式下。如果有人知道我如何修复它,让它在急切模式下工作,那就太好了。在


Tags: loop示例sizetimetf模式callout