tf.while_循环在p中运行时给出错误的结果

2024-04-16 13:00:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我想逐行更新tensorflow中tf.while_loop中的二维tf.variable。为此,我使用tf.assign方法。问题是我的实现和parallel_iterations>1的结果是错误的。使用parallel_iterations=1结果是正确的。代码是这样的:

a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)

i = tf.constant(0)
def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
    temp = tf.assign(a[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

迭代是完全独立的,我不知道是什么问题。在

奇怪的是,如果我像这样更改代码:

^{pr2}$

代码给出parallel_iterations>1的正确结果。有人能解释一下这是怎么回事,给我一个有效的解决方案来更新变量,因为我想更新的原始变量很大,而我发现的解决方案效率很低。在


Tags: 代码loopreturnparallelvartfdefbody
1条回答
网友
1楼 · 发布于 2024-04-16 13:00:00

不需要使用变量,只需在循环体上生成行更新张量:

import tensorflow as tf

def method(i):
    # Placeholder logic
    return tf.cast(tf.range(i, i + 100), tf.float32)

def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    # Produce new row
    updated_row = method(i)
    # Index vector that is 1 only on the row to update
    idx = tf.equal(tf.range(tf.shape(a)[0]), i)
    idx = tf.cast(idx[:, tf.newaxis], var.dtype)
    # Compose the new tensor with the old one and the new row
    var_updated = (1 - idx) * var + idx * updated_row
    return [tf.add(i, 1), var_updated]

# Start with zeros
a = tf.zeros([100, 100], tf.float32)
i = tf.constant(0)
i_end, a_updated = tf.while_loop(condition, body, [i, a], parallel_iterations=10)

with tf.Session() as sess:
    print(sess.run(a_updated))

输出:

^{pr2}$

相关问题 更多 >