概念性理解tf.nn.动态()“输出”与“状态”

2024-04-20 09:52:34 发布

您现在位置:Python中文网/ 问答频道 /正文

上下文

我正在阅读Hands on ML的第二部分,并希望在RNN的损失计算中何时使用“输出”和何时使用“状态”有一点清晰。你知道吗

在这本书中(第396页对于那些有这本书的人),作者说,“注意,完全连接的层连接到states张量,它只包含RNN的最终状态,”指的是一个序列分类器,它展开了28个步骤。由于states变量将具有len(states) == <number_of_hidden_layers>,因此在构建深层RNN时,我一直使用状态[-1]来仅连接到最后一层的最终状态。例如:

# hidden_layer_architecture = list of ints defining n_neurons in each layer
# example: hidden_layer_architecture = [100 for _ in range(5)]
layers = []
for layer_id, n_neurons in enumerate(hidden_layer_architecture):

    hidden_layer = tf.contrib.rnn.BasicRNNCell(n_neurons, 
                                               activation=tf.nn.tanh,                                                                                                                                                                     
                                               name=f'hidden_layer_{layer_id}')

    layers.append(hidden_layer)

recurrent_hidden_layers = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(recurrent_hidden_layers,
                                    X_, dtype=tf.float32)

logits = tf.layers.dense(states[-1], n_outputs, name='outputs')

鉴于作者先前的陈述,这一点与预期一样有效。但是,我不明白什么时候会使用outputs变量(第一个tf.nn.dynamic_rnn()输出)

我已经研究了this question,它在回答细节方面做得很好,并且提到,“如果您只对单元格的最后一个输出感兴趣,那么您只需对时间维度进行切片,就可以选择最后一个元素(例如outputs[:, -1, :])。”我推断这意味着states[-1] == outputs[:, -1, :]的一些东西,测试时是错误的。为什么不是这样?如果输出是单元格在每个时间步的输出,为什么不是这样呢?一般来说。。。你知道吗

问题

何时使用loss函数中来自tf.nn.dynamic_rnn()outputs变量,何时使用states变量?这将如何改变网络的抽象架构?你知道吗

请澄清。你知道吗


Tags: inlayer状态layerstfdynamic作者nn
1条回答
网友
1楼 · 发布于 2024-04-20 09:52:34

这基本上把它分解了:

outputs:RNN顶层输出的完整序列。这意味着,如果您使用MultiRNNCell,这将仅是顶部单元格;下面的单元格中没有任何内容。
一般来说,对于定制的RNNCell实现,这几乎可以是任何东西,但是这里几乎所有的标准单元格都返回状态序列,但是您也可以自己编写一个定制单元格,在将其作为输出返回之前对状态序列执行某些操作(例如线性变换)。你知道吗

state(注意,这就是文档所称的,不是states)是最后一个时间步的完整状态。一个重要的区别是,在MultiRNNCell的情况下,这将包含序列中所有细胞的最终状态,而不仅仅是顶部的状态!此外,此输出的精确格式/类型因所使用的RNNCell而有很大差异(例如,它可以是张量,或张量元组…)。你知道吗

因此,如果你所关心的只是MultiRNNCell中最后一个时间步骤的最高状态,那么你真的有两个应该相同的选项,归结到个人偏好/“清晰性”:

  • outputs[:, -1, :](假设批处理主要格式)只从顶级状态序列中提取最后一个时间步。你知道吗
  • state[-1]只从所有层的最终状态元组中提取顶级状态。你知道吗

在其他情况下,您可能没有此选择:

  • 如果您确实需要完整的序列输出,则需要使用outputs。你知道吗
  • 如果需要MultiRNNCell中较低层的最终状态,则需要使用state。你知道吗

至于相等性检查失败的原因:如果您实际使用==,我相信这会检查明显不同的张量对象的相等性。相反,您可以尝试检查两个对象的,以了解一些简单的玩具场景(微小的状态大小/序列长度),它们应该是相同的。你知道吗

相关问题 更多 >