在批处理之间传递LSTM状态的最佳方法

2024-04-18 22:34:14 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在寻找在批处理之间传递LSTM状态的最佳方法。我搜索了所有东西,但找不到当前实现的解决方案。想象一下我有这样的东西:

cells = [rnn.LSTMCell(size) for size in [256,256]
cells = rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=init_state ,dtype=tf.float32)

现在我想在每个批处理中高效地传递new_state,这样就不需要将其存储回内存,然后使用feed_dict重新馈送给tf。更准确地说,我找到的所有解决方案都使用sess.run来计算new_state,并使用{}将其传递到init_state。有没有什么方法可以避免使用feed-dict的瓶颈?在

我想我应该以某种方式使用tf.assign,但是文档不完整,我找不到任何解决方法。在

我要感谢每一个提前询问的人。在

干杯

弗朗西斯科·萨维里奥

我在stack overflow上找到的所有其他答案都适用于旧版本,或者使用“feed dict”方法传递新状态。例如:

1)TensorFlow: Remember LSTM state for next batch (stateful LSTM)这是通过使用“feed dict”来提供状态占位符来实现的,我想避免这种情况

2)Tensorflow - LSTM state reuse within batch这不适用于状态turple

3)Saving LSTM RNN state between runs in Tensorflow此处相同


Tags: 方法innewforsizeinit状态tf
1条回答
网友
1楼 · 发布于 2024-04-18 22:34:14

LSTMStateTuple不过是输出和隐藏状态的元组。tf.assign创建一个操作,该操作在运行时将存储在张量中的值赋给变量(如果您有特定问题,请询问以便改进文档)。通过使用元组的c属性从元组检索隐藏状态张量,可以使用tf.assign的解决方案(假设您需要隐藏状态)-new_state.c

下面是一个完整的关于玩具问题的例子:https://gist.github.com/iganichev/632b425fed0263d0274ec5b922aa3b2f

相关问题 更多 >