<p>Below is code that updates the LSTM's initial state by defining it as a set of state variables. It also supports multiple layers.</p>
<p>We define two functions: one that creates state variables initialized to zero, and one that returns an operation we can pass to <code>session.run</code> to update the state variables with the LSTM's last hidden state.</p>
<pre class="lang-py prettyprint-override"><code>def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)
</code></pre>
<p>We can use these to update the LSTM's state after each batch. Note that I use <code>tf.nn.dynamic_rnn</code> for unrolling:</p>
<pre class="lang-py prettyprint-override"><code>data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))

# Use an LSTM cell (not a GRU cell) so that each layer's state is an
# LSTMStateTuple, which get_state_variables expects.
cell_layer = tf.contrib.rnn.BasicLSTMCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell_layer] * num_layers)

# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})
</code></pre>
<p>The main difference from <a href="https://stackoverflow.com/a/38053807/2628369">this answer</a> is that with <code>state_is_tuple=True</code> the LSTM's state is an LSTMStateTuple containing two variables (the cell state and the hidden state) rather than a single variable. Using multiple layers then makes the LSTM's state a tuple of LSTMStateTuples, one per layer.</p>
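<p>To make the nesting concrete, here is a minimal plain-Python sketch (no TensorFlow; the layer count, batch size, and unit count are arbitrary illustrative values) of the structure that <code>get_state_variables</code> returns for a two-layer LSTM:</p>

```python
from collections import namedtuple

# Stand-in for tf.contrib.rnn.LSTMStateTuple: (cell state c, hidden state h)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

num_layers = 2
batch_size = 3
num_units = 4

def zero_matrix(rows, cols):
    # Stand-in for a zero-initialized (rows, cols) tensor
    return [[0.0] * cols for _ in range(rows)]

# One LSTMStateTuple per layer, each holding a (batch_size, num_units)
# cell state and hidden state - the same nesting dynamic_rnn expects
# as its initial_state for a MultiRNNCell of LSTM layers.
states = tuple(
    LSTMStateTuple(c=zero_matrix(batch_size, num_units),
                   h=zero_matrix(batch_size, num_units))
    for _ in range(num_layers)
)

print(len(states))          # one entry per layer -> 2
print(len(states[0].c))     # batch dimension of the cell state -> 3
print(len(states[0].h[0]))  # unit dimension of the hidden state -> 4
```

<p>This is why <code>get_state_update_op</code> loops over layers and assigns two tensors per layer: each layer contributes one <code>c</code> and one <code>h</code>.</p>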
<h3>Resetting to zero</h3>
<p>When using the trained model for prediction/decoding, you might want to reset the state to zero. You can then use this function:</p>
<pre class="lang-py prettyprint-override"><code>def get_state_reset_op(state_variables, cell, batch_size):
    # Return an operation to set each variable in a list of LSTMStateTuples to zero
    zero_states = cell.zero_state(batch_size, tf.float32)
    return get_state_update_op(state_variables, zero_states)
</code></pre>
<p>For example, using the graph from above:</p>
<pre class="lang-py prettyprint-override"><code>reset_state_op = get_state_reset_op(states, cell, batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})
</code></pre>