Extremely high cross-entropy loss with high accuracy in a TensorFlow model

Posted 2024-04-18 07:37:13


I am training a model with TensorFlow. Following some tutorials, I built my own model and training loop, as shown below:

import tensorflow as tf
from tensorflow.keras import backend as K  # K.mean is used in the loss terms below

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_mse_loss = tf.keras.metrics.Mean(name='train_mse_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_mse_loss = tf.keras.metrics.Mean(name='test_mse_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

optimizer = tf.keras.optimizers.Adam(0.00001)

@tf.function
def train_step(datas, labels):
  with tf.GradientTape() as tape:
      vq_input, vq_weight_value, output = vae(datas)

      entropy_loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, output))
      mse_losses = K.mean((tf.stop_gradient(vq_input) - vq_weight_value)**2)
      mse_losses_2 = K.mean((vq_input - tf.stop_gradient(vq_weight_value))**2)
      total_loss = entropy_loss + mse_losses + mse_losses_2

  grads = tape.gradient(total_loss, vae.trainable_variables)
  optimizer.apply_gradients(zip(grads, vae.trainable_variables))
  train_loss(total_loss)
  train_accuracy(labels, output)
  train_mse_loss(mse_losses + mse_losses_2)

@tf.function
def test_step(datas, labels):
  vq_input, vq_weight_value, output = vae(datas)

  entropy_loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, output))
  mse_losses = K.mean((tf.stop_gradient(vq_input) - vq_weight_value)**2)
  mse_losses_2 = K.mean((vq_input - tf.stop_gradient(vq_weight_value))**2)
  total_loss = entropy_loss + mse_losses + mse_losses_2

  test_loss(total_loss)
  test_accuracy(labels, output)
  test_mse_loss(mse_losses + mse_losses_2)

train_ds = tf.data.Dataset.from_tensor_slices((train_seg_data, train_seg_label)).batch(39)
test_ds = tf.data.Dataset.from_tensor_slices((test_seg_data, test_seg_label)).batch(39)

EPOCHS = 9999

for epoch in range(EPOCHS):
  # Reset the states at each new epoch
  train_loss.reset_states()
  train_accuracy.reset_states()
  train_mse_loss.reset_states()
  test_loss.reset_states()
  test_accuracy.reset_states()
  test_mse_loss.reset_states()

  for train_datass, train_labelss in train_ds:
    train_step(train_datass, train_labelss)

  for test_datass, test_labelss in test_ds:
    test_step(test_datass, test_labelss)

  template = 'Epoch {}, Loss: {}, mse Loss: {}, Accuracy: {}, Test Loss: {}, Test mse Loss: {}, Test Accuracy: {}'
  print (template.format(epoch+1,
      train_loss.result(),
      train_mse_loss.result(),
      train_accuracy.result(),

      test_loss.result(),
      test_mse_loss.result(),
      test_accuracy.result(),))
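
For reference, the two MSE terms are split on purpose by stop_gradient: the first can only update vq_weight_value, the second can only update vq_input. The minimal sketch below uses stand-in variables (enc_out and codebook are hypothetical, not the actual vae outputs) just to show how the gradient flow is divided:

import tensorflow as tf

# Stand-in variables, only to illustrate which side each MSE term can train
enc_out = tf.Variable([[1.0, 2.0]])     # plays the role of vq_input
codebook = tf.Variable([[0.5, 0.5]])    # plays the role of vq_weight_value

with tf.GradientTape(persistent=True) as tape:
    loss_a = tf.reduce_mean((tf.stop_gradient(enc_out) - codebook) ** 2)  # like mse_losses
    loss_b = tf.reduce_mean((enc_out - tf.stop_gradient(codebook)) ** 2)  # like mse_losses_2

print(tape.gradient(loss_a, enc_out))     # None   -> loss_a cannot move the encoder side
print(tape.gradient(loss_a, codebook))    # tensor -> loss_a trains the codebook side
print(tape.gradient(loss_b, enc_out))     # tensor -> loss_b trains the encoder side
print(tape.gradient(loss_b, codebook))    # None   -> loss_b cannot move the codebook side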

But I found that the test (validation) loss looks extremely high, even though the test accuracy keeps improving. Can anyone tell me why the results look so strange?
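
(One detail that may matter here, offered only as a guess since the vae definition is not shown: if output is raw, un-softmaxed scores, then categorical_crossentropy with its default from_logits=False can report huge values even while argmax-based accuracy keeps rising. A quick check with a hypothetical batch:)

import tensorflow as tf

# Hypothetical batch: the argmax prediction is correct, but the outputs are raw scores
labels = tf.constant([[1.0, 0.0, 0.0]])
raw_scores = tf.constant([[2.0, -1.0, -3.0]])

# Default (treats raw_scores as probabilities): can report a huge, misleading value
print(tf.keras.losses.categorical_crossentropy(labels, raw_scores).numpy())

# Interpreted as logits: small loss, consistent with the correct prediction
print(tf.keras.losses.categorical_crossentropy(labels, raw_scores, from_logits=True).numpy())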


