TensorFlow:从RNN中获取所有状态

2024-04-26 11:43:42 发布

您现在位置:Python中文网/ 问答频道 /正文

如何从TensorFlow中的tf.nn.rnn()或{}获得所有隐藏状态?API只提供最终状态。在

第一种选择是在构建一个直接在rncell上运行的模型时编写一个循环。但是,时间步数对我来说不是固定的,而是取决于传入的批处理。在

有些选项是使用GRU或编写自己的RNNCell,将状态连接到输出。前者的选择不够普遍,后者听起来太老套。在

另一个选择是执行类似the answers in this question的操作,从RNN获取所有变量。但是,我不确定如何在这里以标准方式将隐藏状态与其他变量分开。在

有没有一种好方法可以在使用库提供的rnnapi的同时从RNN获取所有隐藏状态?在


Tags: 模型api状态tftensorflow选项时间nn
2条回答

我已经创建了一个PR here,它可能会帮助您处理简单的案例

让我简单地解释一下我的实现,这样您就可以根据需要编写自己的版本。主要是修改_time_step函数:

def _time_step(time, output_ta_t, state, *args):

除了传入额外的*args,参数保持不变。但是为什么args?因为我想支持tensorflow的习惯行为。您只能通过忽略args参数返回最终状态:

^{pr2}$

如何利用它?在

if args:
    args = tuple(
        ta.write(time, out) for ta, out in zip(args[0], [new_state])
    )

实际上,这只是对以下(原始)代码的修改:

output_ta_t = tuple(
    ta.write(time, out) for ta, out in zip(output_ta_t, output)
)

现在,args应该包含您想要的所有状态。在

完成以上所有工作后,您可以使用以下代码获取状态(或最终状态):

_, output_final_ta, *state_info = control_flow_ops.while_loop( ...

以及

if states_ta is not None:
    final_state, states_final_ta = state_info
else:
    final_state, states_final_ta = state_info[0], None

虽然我没有在复杂的情况下测试它,但它应该在“简单”的条件下工作(here's我的测试用例)

在tf.nn.动态(同时tf.nn.静态)有两个返回值;“outputs”,“state”(https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn

正如您所说,“state”是RNN的最终状态,但是“outputs”都是RNN的隐藏状态(形状是[batch\u size,max_time,cell.output_大小])在

您可以使用“outputs”作为RNN的隐藏状态,因为在大多数库提供的RNNCell中,“output”和“state”是相同的。(LSTMCell除外)

相关问题 更多 >