Why does my model learn with `tf.GradientTape()` but fail with `keras.models.Model.fit()`?


After a lot of effort, I managed to build a TensorFlow 2 implementation of an existing PyTorch fast style transfer project. I then wanted to get all the extra features that come with standard Keras training, e.g. model.fit().

But the same model fails when trained via model.fit(). It seems to learn the content features, but is unable to learn the style features. Here is a diagram of the model in question:

[model diagram]

import tensorflow as tf
from tensorflow.keras.models import Model

def vgg_layers19(content_layers, style_layers, input_shape=(256,256,3)):
  """ creates a VGG model that returns output values for the given layers
  see: https://keras.io/applications/#extract-features-from-an-arbitrary-intermediate-layer-with-vgg19
  Returns: 
    function(x, preprocess=True):
      Args: 
        x: image tuple/ndarray h,w,c(RGB), domain=(0.,255.)
      Returns:
        a tuple of lists, ([content_features], [style_features])
  usage:
    (content_features, style_features) = vgg_layers19(content_layers, style_layers)(x_train)
  """
  preprocessingFn = tf.keras.applications.vgg19.preprocess_input
  base_model = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
  base_model.trainable = False
  content_features = [base_model.get_layer(name).output for name in content_layers]
  style_features = [base_model.get_layer(name).output for name in style_layers]
  output_features = content_features + style_features

  model = Model( inputs=base_model.input, outputs=output_features, name="vgg_layers")
  model.trainable = False

  def _get_features(x, preprocess=True):
    """
    Args:
      x: expecting tensor, domain=255. hwcRGB
    """
    if preprocess and callable(preprocessingFn): 
      x = preprocessingFn(x)
    output = model(x) # call as tf.keras.Layer()
    return ( output[:len(content_layers)], output[len(content_layers):] )

  return _get_features 
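
As a quick smoke test, the extractor can be called on a random batch. Below is a minimal sketch (the layer names match the ones used later in the question; the random batch is just a stand-in for real images):

content_layers = ['block5_conv2']
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1',
                'block4_conv1', 'block5_conv1']
vgg_features_fn = vgg_layers19(content_layers, style_layers)

x = tf.random.uniform((4, 256, 256, 3), minval=0., maxval=255.)  # fake batch, domain=(0.,255.)
content_features, style_features = vgg_features_fn(x, preprocess=True)
print([f.shape for f in content_features])  # [(4, 16, 16, 512)]
print([f.shape for f in style_features])    # (4,256,256,64) ... (4,16,16,512)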



class VGG_Features():
  """ get content and style features from VGG model """
  def __init__(self, loss_model, style_image=None, target_style_gram=None):
    self.loss_model = loss_model
    if style_image is not None:
      assert style_image.shape == (256,256,3), "ERROR: loss_model expecting input_shape=(256,256,3), got {}".format(style_image.shape)
      self.style_image = style_image
      self.target_style_gram = VGG_Features.get_style_gram(self.loss_model, self.style_image)
    if target_style_gram is not None:
      self.target_style_gram = target_style_gram

  @staticmethod
  def get_style_gram(vgg_features_model, style_image):
    style_batch = tf.repeat( style_image[tf.newaxis,...], repeats=_batch_size, axis=0)
    # show([style_image], w=128, domain=(0.,255.) )

    # B, H, W, C = style_batch.shape
    (_, style_features) = vgg_features_model( style_batch , preprocess=True ) # hwcRGB
    target_style_gram = [ fnstf_utils.gram(value)  for value in style_features ]  # list
    return target_style_gram  

  def __call__(self, input_batch):
    content_features, style_features = self.loss_model( input_batch, preprocess=True )
    style_gram = tuple(fnstf_utils.gram(value)  for value in style_features)  # tuple(<generator>)
    return (content_features[0],) + style_gram  # tuple = tuple + tuple
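
fnstf_utils.gram is not shown in the question. For reference, the Gram matrix used in style transfer is conventionally computed as sketched below; this is the standard formulation (the notebook's fnstf_utils.gram may normalize differently), and it matches the (None, C, C) output shapes listed in the next class:

def gram(feature_map):
  """ Gram matrix of a (B,H,W,C) feature map, normalized by the map size """
  _, h, w, c = feature_map.shape
  x = tf.reshape(feature_map, (-1, h * w, c))   # (B, H*W, C)
  g = tf.matmul(x, x, transpose_a=True)         # (B, C, C)
  return g / tf.cast(h * w * c, tf.float32)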




class TransformerNetwork_VGG(tf.keras.Model):
  def __init__(self, transformer=transformer, vgg_features=vgg_features):
    super(TransformerNetwork_VGG, self).__init__()
    self.transformer = transformer 
    # type: tf.keras.models.Model
    # input_shapes:  (None, 256,256,3)
    # output_shapes: (None, 256,256,3)


    style_model = {
       'content_layers':['block5_conv2'],
       'style_layers': ['block1_conv1',
                  'block2_conv1',
                  'block3_conv1', 
                  'block4_conv1', 
                  'block5_conv1']
    }
    vgg_model = vgg_layers19( style_model['content_layers'], style_model['style_layers'] )

    self.vgg_features = VGG_Features(vgg_model, style_image=style_image)  # __init__ above takes no batch_size kwarg

    # input_shapes:  (None, 256,256,3)
    # output_shapes: [(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    #                [ content_loss,        style_loss_1, style_loss_2, style_loss_3, style_loss_4, style_loss_5 ]


  def call(self, inputs):
    x = inputs                # shape=(None, 256,256,3)

    # shape=(None, 256,256,3)
    generated_image = self.transformer(x)                    

    # shape=[(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    vgg_feature_losses = self.vgg_features(generated_image)

    return vgg_feature_losses       # tuple(content1, style1, style2, style3, style4, style5)
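
A quick sanity check of the composed model's output shapes (a sketch; transformer and style_image are assumed to exist as in the notebook, and the random input is a placeholder):

model = TransformerNetwork_VGG()
x = tf.random.uniform((4, 256, 256, 3), minval=0., maxval=255.)
outputs = model(x)
print([o.shape for o in outputs])
# expected: [(4, 16, 16, 512), (4, 64, 64), (4, 128, 128), (4, 256, 256), (4, 512, 512), (4, 512, 512)]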

The style image: [style image]

FEATURE_WEIGHTS = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # one weight per output (content + 5 style)

Training with GradientTape

Using a manual tf.GradientTape() loop, I handle the multiple outputs from TransformerNetwork_VGG(x_train), e.g. the tuple of 6 tensors, myself. This approach learns correctly:

  @tf.function()
  def train_step(x_train, y_true, loss_weights=None, log_freq=10):
    content_weight, *style_weights = loss_weights   # 1 content weight + 5 style weights
    with tf.GradientTape() as tape:
      y_pred = TransformerNetwork_VGG(x_train)      # instance of the Model subclass above
      generated_content_features = y_pred[:1]
      generated_style_gram = y_pred[1:]

      # targets: content features from the current batch, style grams precomputed from the style image
      y_true = TransformerNetwork_VGG.vgg_features(x_train)
      target_content_features = y_true[:1]
      target_style_gram = TransformerNetwork_VGG.vgg_features.target_style_gram

      content_loss = get_MEAN_mse_loss(target_content_features, generated_content_features) * content_weight
      style_loss = tuple(get_MEAN_mse_loss(x, y) * w for x, y, w in zip(target_style_gram, generated_style_gram, style_weights))
      total_loss = content_loss + tf.reduce_sum(style_loss)

    # compute and apply gradients outside the tape context
    TransformerNetwork = TransformerNetwork_VGG.transformer
    grads = tape.gradient(total_loss, TransformerNetwork.trainable_weights)
    optimizer.apply_gradients(zip(grads, TransformerNetwork.trainable_weights))
# GradientTape epoch=5: 
# losses:             [   6078.71         70.23  4495.13 13817.65 88217.99    48.36]

[GradientTape training result]

Training with model.fit()

With tf.keras.models.Model.fit(), each of the multiple outputs (e.g. the tuple of 6 tensors) is fed separately into a loss function as loss(y_true, y_pred) and then multiplied by the corresponding weight at reduction. This method does learn to approximate the content image, but it fails to learn to minimize the style losses! I don't know why.
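
For context, a multi-output model like this is typically compiled with one loss per output plus loss_weights. A hedged sketch, not the notebook's exact code (it assumes get_MEAN_mse_loss follows the standard loss(y_true, y_pred) signature and that train_dataset yields a matching 6-tuple of targets):

  TransformerNetwork_VGG.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=[get_MEAN_mse_loss] * 6,    # one loss per output: content + 5 style grams
    loss_weights=FEATURE_WEIGHTS,    # applied at reduction
  )

The training call is then: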

  history = TransformerNetwork_VGG.fit(
    x=train_dataset.repeat(NUM_EPOCHS),
    epochs=NUM_EPOCHS,
    steps_per_epoch=NUM_BATCHES,
    callbacks=callbacks,
  )
# model.fit() epoch=5: 
# losses:             [  4661.08       219.95   6959.01   4897.39 209201.16     84.68]

[model.fit() training result]

After 50 epochs with boosted style weights: FEATURE_WEIGHTS = [0.1854, 1605.23, 25.08, 8.16, 1.28, 2330.79]  # boost style loss x100

[model.fit() result after 50 epochs]

step=50, losses=[269899.45 337.5 69617.7 38424.96 9192.36 85903.44 66423.51]

Checking MSE loss * weights

I tested my model with the losses and weights fixed as follows:

* FEATURE_WEIGHTS = sequence = [1., 2., 3., 4., 5., 6.]
* MSELoss(y_true, y_pred) == tf.ones() of the same shape

and confirmed that model.fit() handles the multiple output losses * weights correctly.
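
A sketch of that sanity check, assuming a dummy loss that ignores its inputs and always returns 1.0, so each reported per-output loss should come out equal to its weight:

  def ones_loss(y_true, y_pred):
    # constant loss == 1.0, regardless of the prediction
    return tf.reduce_mean(tf.ones_like(y_pred))

  TransformerNetwork_VGG.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=[ones_loss] * 6,
    loss_weights=[1., 2., 3., 4., 5., 6.],  # reported losses should scale 1..6
  )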

[losses fixed to ones]

I have checked everything I can think of, but I cannot figure out how to get the model to learn correctly with model.fit(). What am I missing?

The full notebook is available here: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_FastStyleTransfer.ipynb


Tags: image, self, none, target, model, style, layers, tf