How do I connect an attention mechanism to an RNN layer in TensorFlow?

I want to use both output1 from the three-layer encoder and output2 from the encoder's attention layer.

However, my code first creates the encoder cells, then wraps the attention mechanism around one of the encoder cells, and only then builds the full LSTM layer.

In short, I want to get output1 from the LSTM layers first, and then output2 after passing it through an attention layer. As it stands, though, only a single combined LSTM+attention layer is built in one pass. Can I create separate LSTM and attention layers and chain them together? (A rough sketch of what I mean follows the code below.)

See the two source files avsr/encoder.py and avsr/cells.py in https://github.com/georgesterpu/Sigmedia-AVSR.

# avsr/encoder.py (excerpt): the attention mechanism is wrapped around the
# last encoder cell, and only then is the whole stack unrolled in one pass.
self._encoder_cells = build_rnn_layers(
    cell_type=self._hparams.cell_type,
    num_units_per_layer=self._num_units_per_layer,
    use_dropout=self._hparams.use_dropout,
    dropout_probability=self._hparams.dropout_probability,
    mode=self._mode,
    as_list=True,
    dtype=self._hparams.dtype)

attention_mechanism, output_attention = create_attention_mechanism(
    attention_type=self._hparams.attention_type[0][0],
    num_units=self._num_units_per_layer[-1],
    memory=self._attended_memory,
    memory_sequence_length=self._attended_memory_length,
    mode=self._mode,
    dtype=self._hparams.dtype)

# Wrap the last encoder cell with the attention mechanism.
attention_cells = seq2seq.AttentionWrapper(
    cell=self._encoder_cells[-1],
    attention_mechanism=attention_mechanism,
    attention_layer_size=self._hparams.decoder_units_per_layer[-1],
    alignment_history=self._hparams.write_attention_alignment,
    output_attention=output_attention)

self._encoder_cells[-1] = attention_cells

# Unroll the full LSTM+attention stack in a single dynamic_rnn call.
self._encoder_outputs, self._encoder_final_state = tf.nn.dynamic_rnn(
    cell=MultiRNNCell(self._encoder_cells),
    inputs=encoder_inputs,
    sequence_length=self._inputs_len,
    parallel_iterations=self._hparams.batch_size[0 if self._mode == 'train' else 1],
    swap_memory=False,
    dtype=self._hparams.dtype,
    scope=scope)
# avsr/cells.py (excerpt)
def create_attention_mechanism(
        attention_type,
        num_units,
        memory,
        memory_sequence_length,
        mode,
        dtype):

    if attention_type == 'bahdanau':
        attention_mechanism = seq2seq.BahdanauAttention(
            num_units=num_units,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            normalize=False,
            dtype=dtype,
        )
        output_attention = False

    return attention_mechanism, output_attention
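
To make the question concrete, here is a minimal sketch of the two-stage layout I have in mind, assuming the TF 1.x tf.contrib.rnn / tf.contrib.seq2seq API that the repository uses. The placeholder shapes, unit sizes and scope names (plain_encoder, attention_layer) are made up for illustration and are not taken from the original code:

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

# Placeholders standing in for the real model inputs (shapes are illustrative).
encoder_inputs = tf.placeholder(tf.float32, [None, None, 128])    # encoder input features
inputs_len = tf.placeholder(tf.int32, [None])                     # lengths of encoder_inputs
attended_memory = tf.placeholder(tf.float32, [None, None, 256])   # memory to attend over
attended_memory_length = tf.placeholder(tf.int32, [None])         # lengths of the memory

# Stage 1: plain stacked LSTM encoder, unrolled on its own -> output1.
plain_cells = rnn.MultiRNNCell([rnn.LSTMCell(256) for _ in range(3)])
output1, _ = tf.nn.dynamic_rnn(
    cell=plain_cells,
    inputs=encoder_inputs,
    sequence_length=inputs_len,
    dtype=tf.float32,
    scope='plain_encoder')

# Stage 2: a separate attention layer fed with output1 -> output2.
attention_mechanism = seq2seq.BahdanauAttention(
    num_units=256,
    memory=attended_memory,
    memory_sequence_length=attended_memory_length)
attention_cell = seq2seq.AttentionWrapper(
    cell=rnn.LSTMCell(256),
    attention_mechanism=attention_mechanism,
    attention_layer_size=256,
    output_attention=True)
output2, _ = tf.nn.dynamic_rnn(
    cell=attention_cell,
    inputs=output1,
    sequence_length=inputs_len,
    dtype=tf.float32,
    scope='attention_layer')

With this split, output1 is already available before any attention is applied, and output2 comes from a second dynamic_rnn pass over output1, instead of everything being produced in one pass by a single MultiRNNCell whose last cell is an AttentionWrapper.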
