How do I use the validation split in a TensorFlow transformer model?

Posted 2024-05-28 18:24:18


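I am running the TensorFlow "Transformer model for language understanding" tutorial on my own data. I noticed that at the start they split the data into train, val and test sets. After that they never use the val or test sets; instead they "evaluate" the model by manually typing in sentences to translate. I tried to make the model use the validation split by adding the following: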

@tf.function()
def val_step(vinp):

  encoder_input = vinp
  start, end = tokenizers.en.tokenize([''])[0]
  output = tf.convert_to_tensor([start])
  output = tf.expand_dims(output, 0)

  for i in range(40):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)

    # predictions.shape == (batch_size, seq_len, vocab_size)
    predictions, attention_weights = transformer(encoder_input,
                                                 output,
                                                 False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)

    # select the last word from the seq_len dimension
    predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.argmax(predictions, axis=-1)

    # concatenate the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)

    # return the result if the predicted_id is equal to the end token
    if predicted_id == end:
      break

  # output.shape (1, tokens)
  text = tokenizers.en.detokenize(output)[0]  # shape: ()

  tokens = tokenizers.en.lookup(output)[0]

  return text, tokens, attention_weights

Then I added a second for loop to the training code block:

  for (batch, (inp, tar)) in enumerate(train_batches):
    tar = tf.squeeze(tf.transpose(tar, [0, 2, 1]), -1)
    train_step(inp, tar)

    if batch % 50 == 0:
      print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')
  
  for (batch, (vinp, vtar)) in enumerate(val_batches):
    val_step(vinp)

This gives me OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
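The error most likely comes from `start, end = tokenizers.en.tokenize([''])[0]` in `val_step`: tuple unpacking iterates over the tensor, which works in eager mode (where the tutorial's `evaluate` runs) but is not allowed on a symbolic tensor inside `@tf.function`. Even with that fixed, `if predicted_id == end: break` inside a Python `range` loop is likely to fail next, because AutoGraph does not support a `break` that depends on a tensor condition in a Python loop. The following is a minimal, self-contained sketch of a graph-compatible greedy decoding loop, assuming the tokenizer returns int64 ids; `toy_transformer`, `VOCAB_SIZE` and `MAX_LENGTH` are hypothetical stand-ins (not part of the tutorial), and a real version would pass the encoder input and masks to `transformer` as the original code does.

```python
import tensorflow as tf

VOCAB_SIZE = 10  # hypothetical vocabulary size
MAX_LENGTH = 8   # hypothetical decoding limit (the question uses 40)

# Hypothetical stand-in for the transformer call, only so the sketch runs:
# it always predicts "last token + 1" as the next token.
def toy_transformer(output):
    next_id = output[:, -1:] + 1                        # (1, 1)
    return tf.one_hot(next_id, VOCAB_SIZE) * 10.0       # (1, 1, vocab) logits

@tf.function
def greedy_decode(start_end):
    # `start_end` plays the role of tokenizers.en.tokenize([''])[0], i.e. the
    # tensor [START, END]. Index it instead of writing `start, end = ...`,
    # because unpacking iterates over a symbolic tensor.
    start = start_end[0][tf.newaxis]                    # (1,)
    end = start_end[1][tf.newaxis]                      # (1,)

    # A TensorArray can grow inside a graph-mode loop, unlike a Python list.
    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    output_array = output_array.write(0, start)

    # tf.range (not range) lets AutoGraph build a tf.while_loop, so the
    # tensor-valued break condition below is supported.
    for i in tf.range(MAX_LENGTH):
        output = tf.transpose(output_array.stack())     # (1, i + 1)
        predictions = toy_transformer(output)[:, -1:, :]
        predicted_id = tf.argmax(predictions, axis=-1)  # (1, 1), int64
        output_array = output_array.write(i + 1, predicted_id[0])
        if predicted_id[0, 0] == end[0]:
            break

    return tf.transpose(output_array.stack())           # (1, tokens)
```

With START = 2 and END = 5, the toy model emits 3, 4, 5 and the loop stops at the end token, so `greedy_decode(tf.constant([2, 5], dtype=tf.int64))` returns `[[2, 3, 4, 5]]`.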

Does anyone know what causes this error, or another way to use the validation data?
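Greedy decoding is also a poor fit for per-batch validation: `val_step` starts decoding from a single start token of shape `(1, 1)` while `vinp` is a whole batch, and the decoded text is not a loss. A common alternative is to evaluate the validation split the same way `train_step` trains, with teacher forcing but without a `GradientTape` or optimizer step, and to track validation loss and accuracy. The sketch below assumes padding id 0 and int64 token ids; `toy_transformer` and its variables are hypothetical stand-ins for the tutorial's `transformer` and `create_masks`, and the masked loss and accuracy follow the same idea as the tutorial's functions but are written out so the block runs on its own.

```python
import tensorflow as tf

VOCAB_SIZE = 12  # hypothetical vocabulary size

# Hypothetical stand-in for the tutorial's `transformer`, only so the sketch
# runs: embeds the decoder input and projects it to vocabulary logits.
# It ignores the encoder input; the real model also takes the three masks.
_emb = tf.Variable(tf.random.normal([VOCAB_SIZE, 8], seed=0))
_proj = tf.Variable(tf.random.normal([8, VOCAB_SIZE], seed=1))

def toy_transformer(inp, tar_inp, training):
    x = tf.nn.embedding_lookup(_emb, tar_inp)           # (batch, seq, 8)
    return tf.einsum('bsd,dv->bsv', x, _proj)           # (batch, seq, vocab)

def loss_function(real, pred):
    # Per-token cross-entropy, averaged over non-padding positions (id 0).
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
    mask = tf.cast(tf.not_equal(real, 0), loss.dtype)
    return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)

def accuracy_function(real, pred):
    # Fraction of non-padding positions where the argmax token is correct.
    matches = tf.equal(real, tf.argmax(pred, axis=2))
    mask = tf.not_equal(real, 0)
    matches = tf.cast(tf.logical_and(matches, mask), tf.float32)
    return tf.reduce_sum(matches) / tf.reduce_sum(tf.cast(mask, tf.float32))

val_loss = tf.keras.metrics.Mean(name='val_loss')
val_accuracy = tf.keras.metrics.Mean(name='val_accuracy')

@tf.function
def val_step(inp, tar):
    # Same shift as train_step: feed tar[:, :-1], predict tar[:, 1:].
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    # With the tutorial's model this would be create_masks(inp, tar_inp) followed by
    # transformer(inp, tar_inp, False, enc_padding_mask, combined_mask, dec_padding_mask).
    predictions = toy_transformer(inp, tar_inp, training=False)
    # No GradientTape and no optimizer update: validation only measures.
    val_loss(loss_function(tar_real, predictions))
    val_accuracy(accuracy_function(tar_real, predictions))
```

In the epoch loop, call `val_loss.reset_state()` and `val_accuracy.reset_state()` at the start of each epoch, apply the same `tf.squeeze(tf.transpose(...), -1)` reshaping to `vtar` that is applied to `tar`, call `val_step(vinp, vtar)` for every validation batch, and print `val_loss.result()` and `val_accuracy.result()` once the loop finishes.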

