TensorFlow仅在使用MultiRNNC时抛出错误

2024-05-16 00:27:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我在TensorFlow 1.0.1中使用传统的序列到序列框架构建一个编解码器模型。当我在编码器和解码器中有一层LSTM时,一切都正常工作。但是,当我尝试用MultiRNNCell包装的LSTMs的1层时,我在调用tf.contrib.legacy_seq2seq.rnn_decoder时出错。在

完整的错误在这篇文章的末尾,但是简单地说,它是由一行代码引起的

(c_prev, m_prev) = state

在TensorFlow中抛出TypeError: 'Tensor' object is not iterable.。我对此感到困惑,因为我传递给rnn_decoder的初始状态实际上是一个元组。据我所知,使用1层或gt;1层之间的唯一区别是后者涉及使用MultiRNNCell。在使用这个的时候,有没有一些API的问题需要我知道?在

这是我的代码(基于thisGitHub repo中的示例)。为它的长度道歉;这是我能做的最小的,同时仍然是完整的和可验证的。在

^{pr2}$

这是一个错误:

Traceback (most recent call last):
  File "example.py", line 67, in <module>
    decoder = seq2seq.rnn_decoder(decoder_inputs, initial_dec_state, dec_cell)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 150, in rnn_decoder
    output, state = cell(inp, state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 426, in __call__
    output, res_state = self._cell(inputs, state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 655, in __call__
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 321, in __call__
    (c_prev, m_prev) = state
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 502, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

谢谢你!在


Tags: inpyhomeexampleliblinesitecell
1条回答
网友
1楼 · 发布于 2024-05-16 00:27:05

问题在于传递给seq2seq.rnn_decoder的初始状态(initial_dec_state)的格式。在

当您使用rnn.MultiRNNCell时,您正在构建一个多层递归网络,因此需要为这些层的每个层提供一个初始状态。在

因此,您应该提供一个元组的列表作为初始状态,其中列表的每个元素都是来自递归网络相应层的前一个状态。在

因此,您的initial_dec_state初始化如下:

    initial_dec_state = tuple([tf.concat([final_fw_state[-1][i],
                                      final_bw_state[-1][i]], 1) 
                           for i in range(2)])

应该是这样的:

^{pr2}$

它以以下格式创建元组列表:

    [(state_c1, state_m1), (state_c2, state_m2) ...]

更详细地说,发生'Tensor' object is not iterable.错误是因为seq2seq.rnn_decoder内部调用了您的向其传递初始状态(initial_dec_state)的{}(dec_cell)。在

rnn.MultiRNNCell.__call__遍历初始状态列表,并为每个初始状态提取元组(c_prev, m_prev)(在(c_prev, m_prev) = state语句中)。在

因此,如果只传递一个元组,rnn.MultiRNNCell.__call__将迭代它,一旦它到达(c_prev, m_prev) = state,它就会找到一个张量(应该是一个元组)作为state并抛出'Tensor' object is not iterable.错误。在

知道seq2seq.rnn_decoder期望的初始状态格式的一个好方法是调用dec_cell.zero_state(batch_size, dtype=tf.float32)。此方法以初始化所使用的递归模块所需的格式返回零填充状态张量。在

相关问题 更多 >