Tensorflow力操作执行

2024-04-20 00:03:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在用Python编写一个tensorflow函数来实现一个通用的lag。它有一个内部状态,每次运行会话时都必须更新该状态。在

下面是一个简单的1步滞后的最小示例:

def lag(x, name=None):
    with tf.name_scope(name, "lag"):
        zeros = tf.zeros(x.get_shape(), dtype=x.dtype)
        cache = tf.Variable(zeros, name="cache")
        output = tf.Variable(zeros, name="output")
        output = tf.assign(output, cache)
        cache = tf.assign(cache, x)
        return output

让我们试试:

^{pr2}$

我们得到的结果是[0, 0, 0],而我们想要的是[0, 1, 2]。在

我们得到这个结果的原因是因为lag中的最后一个cacheop没有被显式使用,而且永远不会被计算。我们可以通过在return前添加一行来强制计算它:

output = tf.tuple([output, cache])[0]

然后我们得到预期的输出[0, 1, 2]。但这看起来很不雅观。有没有更好的方法来迫使一个手术被评估?在

另外一个问题是,在这个实现中,我们可以看到使用了两个Variable。我找不到另一种方法来临时复制一个张量,但我不明白为什么我不能用一个Variable来做同样的事情。还有别的办法吗?在


Tags: 方法函数name示例cacheoutputreturn状态