在GANEstim中使用对抗性损失

2024-04-26 14:12:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我现在尝试结合一级像素损失和一个对抗性损失来学习自动编码图像。代码如下。在

gan_model = tfgan.gan_model(
    generator_fn=nets.autoencoder,
    discriminator_fn=nets.discriminator,
    real_data=images,
    generator_inputs=images)

gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss,
    weight_factor=FLAGS.weight_factor)

# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

# Run the train ops in the alternating training scheme.
tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)

不过,我想使用GANEstimator来简化代码。GANEstimator的典型例子如下。在

^{pr2}$

有人知道如何在GANEstimator中使用组合的“对抗性损失”?在

谢谢。在


Tags: thedatamodeltftraingeneratoropsfn
2条回答

在您的链接中,GANEstimator具有以下参数:

 generator_loss_fn=None,
 discriminator_loss_fn=None,

generator_loss_fn应该是你的l1像素丢失。在

discriminator_loss_fn应该是你的共同对抗性损失。在

我刚刚遇到了同样的问题(这个解决方案是针对TensorFlow r1.12)。在

通读代码,^{}gan_loss元组,用联合敌方损失代替生成器损失。这意味着我们需要替换估计器中的generator_loss_fn。估计器的所有其他损失函数都有参数:gan_model, **kwargs。我们定义了自己的函数,并将其用作发电机损耗函数:

def combined_loss(gan_model, **kwargs):
    # Define non-adversarial loss - for example L1
    non_adversarial_loss = tf.losses.absolute_difference(
        gan_model.real_data, gan_model.generated_data)
    # Define generator loss
    generator_loss = tf.contrib.gan.losses.wasserstein_generator_loss(
        gan_model,
        **kwargs)
    # Combine these losses - you can specify more parameters
    # Exactly one of weight_factor and gradient_ratio must be non-None
    combined_loss = tf.contrib.gan.losses.wargs.combine_adversarial_loss(
        non_adversarial_loss,
        generator_loss,
        weight_factor=FLAGS.weight_factor,
        gradient_ratio=None,
        variables=gan_model.generator_variables,
        scalar_summaries=kwargs['add_summaries'],
        gradient_summaries=kwargs['add_summaries'])
    return combined_loss


gan_estimator = tf.contrib.gan.estimator.GANEstimator(
    model_dir,
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=combined_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))

有关参数的详细信息,请参阅文档:^{} 而且**kwargs与组合的对抗性损失函数不兼容,所以我使用了一个小技巧。在

相关问题 更多 >