我正在寻找在批处理之间传递LSTM状态的最佳方法。我搜索了所有东西,但找不到当前实现的解决方案。想象一下我有这样的东西:
cells = [rnn.LSTMCell(size) for size in [256,256]
cells = rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=init_state ,dtype=tf.float32)
现在我想在每个批处理中高效地传递new_state
,这样就不需要将其存储回内存,然后使用feed_dict
重新馈送给tf。更准确地说,我找到的所有解决方案都使用sess.run
来计算new_state
,并使用{init_state
。有没有什么方法可以避免使用feed-dict
的瓶颈?在
我想我应该以某种方式使用tf.assign
,但是文档不完整,我找不到任何解决方法。在
我要感谢每一个提前询问的人。在
干杯
弗朗西斯科·萨维里奥
我在stack overflow上找到的所有其他答案都适用于旧版本,或者使用“feed dict”方法传递新状态。例如:
1)TensorFlow: Remember LSTM state for next batch (stateful LSTM)这是通过使用“feed dict”来提供状态占位符来实现的,我想避免这种情况
2)Tensorflow - LSTM state reuse within batch这不适用于状态turple
LSTMStateTuple
不过是输出和隐藏状态的元组。tf.assign
创建一个操作,该操作在运行时将存储在张量中的值赋给变量(如果您有特定问题,请询问以便改进文档)。通过使用元组的c
属性从元组检索隐藏状态张量,可以使用tf.assign
的解决方案(假设您需要隐藏状态)-new_state.c
下面是一个完整的关于玩具问题的例子:https://gist.github.com/iganichev/632b425fed0263d0274ec5b922aa3b2f
相关问题 更多 >
编程相关推荐