<p>I currently use tf.control_dependencies to save the RNN state. Here is an example.</p>
<pre><code> saved_states = [tf.get_variable('saved_state_%d' % i, shape = (BATCH_SIZE, sz), trainable = False, initializer = tf.constant_initializer()) for i, sz in enumerate(rnn.state_size)]
W = tf.get_variable('W', shape = (2 * RNN_SIZE, RNN_SIZE), initializer = tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
b = tf.get_variable('b', shape = (RNN_SIZE,), initializer = tf.constant_initializer())
rnn_output, states = rnn(last_output, saved_states)
with tf.control_dependencies([tf.assign(a, b) for a, b in zip(saved_states, states)]):
dense_input = tf.concat(1, (last_output, rnn_output))
dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
last_output = dense_output + last_output
</code></pre>
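<p>I just make sure that part of my graph depends on the ops that save the state, so the state is written whenever that part is evaluated.</p>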
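<p>To see the same pattern in isolation, here is a minimal sketch that replaces the RNN with a running sum. It assumes a TensorFlow installation that provides <code>tensorflow.compat.v1</code>; the names <code>run_demo</code> and <code>demo_state</code> are made up for this illustration and are not part of the code above.</p>

```python
import tensorflow.compat.v1 as tf


def run_demo(inputs):
    """Feed each value in turn; return the per-step outputs and the final saved state."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=())
        # Hypothetical stand-in for the saved RNN state: one non-trainable scalar.
        state = tf.get_variable('demo_state', shape=(), trainable=False,
                                initializer=tf.constant_initializer(0.0))
        new_state = state + x  # stand-in for the RNN step
        # Only ops created inside the block pick up the dependency,
        # so the output is wrapped in tf.identity here.
        with tf.control_dependencies([tf.assign(state, new_state)]):
            out = tf.identity(new_state)
        init = tf.global_variables_initializer()

    with tf.Session(graph=graph) as sess:
        sess.run(init)
        # Fetching `out` alone is enough to trigger the assignment.
        outputs = [float(sess.run(out, feed_dict={x: v})) for v in inputs]
        final_state = float(sess.run(state))
    return outputs, final_state


if __name__ == '__main__':
    print(run_demo([1.0, 2.0, 3.0]))
```

<p>Each call to <code>sess.run(out, ...)</code> also updates the saved variable, even though the assign op is never fetched explicitly. That is the effect the answer relies on for the RNN state.</p>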