我有一个形状为[batch, None, dim]
的三维张量,其中第二维,即时间步,是未知的。我使用dynamic_rnn
来处理这样的输入,如下片段所示:
import numpy as np
import tensorflow as tf
batch = 2
dim = 3
hidden = 4
lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, _ = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)
实际上,用一些实际的数字截取这个片段,我得到了一些合理的结果:
inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
[[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
dtype=np.int32)
lengths_ = np.asarray([3, 1], dtype=np.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
output_ = sess.run(output, {inputs: inputs_, lengths: lengths_})
print(output_)
结果是:
[[[ 0. 0. 0. 0. ]
[ 0.02188676 -0.01294564 0.05340237 -0.47148666]
[ 0.0343586 -0.02243731 0.0870839 -0.89869428]
[ 0. 0. 0. 0. ]]
[[ 0.00284752 -0.00315077 0.00108094 -0.99883419]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
有没有办法用动态RNN的最后一个相关输出得到形状为[batch, 1, hidden]
的三维张量?谢谢!
从以下两个来源
http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/
或者https://github.com/ageron/handson-ml/blob/master/14_recurrent_neural_networks.ipynb
显然,最后的状态可以直接从动态调用的第二个输出中提取。它将为您提供跨所有层的最后一个状态(在LSTM中,它是从LSTMStateTuple合成的),而输出包含最后层中的所有状态。
好吧-看来实际上是一个更简单的解决方案。正如@Shao Tang和@Rahul提到的,最好的方法是访问最终的细胞状态。原因如下:
tf.nn.dynamic_rnn
返回最终状态时,实际上是返回您感兴趣的最终隐藏权重。为了证明这一点,我调整了你的设置并得到了结果:GRUCell调用(rnn_cell_impl.py):
解决方案:
输出:
对于正在使用LSTMCell(另一个流行的选项)的其他读者来说,情况有点不同。LSTMCell以不同的方式维护状态-cell state是实际cell state和hidden state的元组或连接版本。因此,要访问最终隐藏权重,可以在单元格初始化期间设置(
is_state_tuple
到True
),最终状态将是一个元组:(final cell state,final hidden weights)。所以,在这种情况下_2;,(2;,h)=tf.nn.dynamic_rnn(单元,输入,长度,初始状态=单元状态)
会给你最后的重量。
参考文献: c_state and m_state in Tensorflow LSTMhttps://github.com/tensorflow/tensorflow/blob/438604fc885208ee05f9eef2d0f2c630e1360a83/tensorflow/python/ops/rnn_cell_impl.py#L308https://github.com/tensorflow/tensorflow/blob/438604fc885208ee05f9eef2d0f2c630e1360a83/tensorflow/python/ops/rnn_cell_impl.py#L415
这就是gather_nd的目的!
就你而言:
现在
output
是维度[batch_size, num_cells]
的张量。相关问题 更多 >
编程相关推荐