tfa.seq2seq.BeamSearchDecoder produces an output with an empty time-step dimension

Published 2024-05-15 04:58:31


I am building an encoder-decoder model in TensorFlow 2 with TensorFlow Addons. For prediction I am trying to use the tfa.seq2seq.BeamSearchDecoder class. Unfortunately, the output tensor I get has no time-step dimension:

final_outputs:  (64, 5, None)
outputs:  (64, None)
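As an aside, the None in these printed shapes is TensorFlow's marker for a dynamic axis: when a length is data-dependent, the traced (symbolic) shape cannot pin it down, even though an eager call still returns a concrete tensor. A minimal sketch in plain TensorFlow (no tfa; the function name here is made up for illustration) showing the same effect:

```python
import tensorflow as tf

@tf.function
def trim(x):
    # Boolean masking makes the output length data-dependent, so the
    # traced shape reports None for that axis.
    return tf.boolean_mask(x, x > 0)

# Symbolic shape at trace time: the length axis is None.
concrete = trim.get_concrete_function(tf.TensorSpec([4], tf.int32))
print(concrete.output_shapes)

# An eager call still yields a concrete tensor: [1 3]
print(trim(tf.constant([1, -2, 3, 0])))
```

So a None by itself is not necessarily an error; the question is whether the decoded sequences themselves are missing.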

The error most likely occurs in the call method of my decoder class:

def call(self, dec_input, enc_output, enc_hidden, start_token=1, end_token=2, training=False):
    start_tokens = tf.fill([self.batch_sz], start_token)

    # Tile encoder output across beams: [beam_width * batch_size, max_length_input, rnn_units]
    enc_out = tfa.seq2seq.tile_batch(enc_output, multiplier=self.beam_width)
    self.attention_mechanism.setup_memory(enc_out)

    # Set decoder_initial_state, an AttentionWrapperState that accounts for beam_width
    hidden_state = tfa.seq2seq.tile_batch(enc_hidden, multiplier=self.beam_width)
    decoder_initial_state = self.rnn_cell.get_initial_state(
        batch_size=self.beam_width * self.batch_sz, dtype=tf.float32)
    decoder_initial_state = decoder_initial_state.clone(cell_state=hidden_state)

    # Instantiate the BeamSearchDecoder
    decoder_instance = tfa.seq2seq.BeamSearchDecoder(
        self.rnn_cell, beam_width=self.beam_width, output_layer=self.fc)
    decoder_embedding_matrix = self.embd_layer.variables[0]
    print("decoder_embedding_matrix: ", decoder_embedding_matrix.shape)

    # The BeamSearchDecoder object's call() takes care of the decoding loop.
    outputs, _, _ = decoder_instance(
        decoder_embedding_matrix, start_tokens=start_tokens,
        end_token=end_token, initial_state=decoder_initial_state)

    # predicted_ids: [batch, time, beam_width] -> [batch, beam_width, time]
    final_outputs = tf.transpose(outputs.predicted_ids, perm=(0, 2, 1))

    outputs = final_outputs[:, 0, :]  # best beam: [batch, length]
    return outputs

Please let me know if you need any further information or have any ideas on how to solve this.

Thanks, and have a nice day

Sören

