我正在用Python编写一个tensorflow函数来实现一个通用的lag。它有一个内部状态,每次运行会话时都必须更新该状态。在
下面是一个简单的1步滞后的最小示例:
def lag(x, name=None):
with tf.name_scope(name, "lag"):
zeros = tf.zeros(x.get_shape(), dtype=x.dtype)
cache = tf.Variable(zeros, name="cache")
output = tf.Variable(zeros, name="output")
output = tf.assign(output, cache)
cache = tf.assign(cache, x)
return output
让我们试试:
^{pr2}$我们得到的结果是[0, 0, 0]
,而我们想要的是[0, 1, 2]
。在
我们得到这个结果的原因是因为lag
中的最后一个cache
op没有被显式使用,而且永远不会被计算。我们可以通过在return
前添加一行来强制计算它:
output = tf.tuple([output, cache])[0]
然后我们得到预期的输出[0, 1, 2]
。但这看起来很不雅观。有没有更好的方法来迫使一个手术被评估?在
另外一个问题是,在这个实现中,我们可以看到使用了两个Variable
。我找不到另一种方法来临时复制一个张量,但我不明白为什么我不能用一个Variable
来做同样的事情。还有别的办法吗?在
目前没有回答
相关问题 更多 >
编程相关推荐