嵌套模型不在梯度带中训练

2024-05-15 00:27:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我已经为编码器创建了一个模型,该编码器既用于生成器,也用于鉴别器。虽然相同的编码器模型在生成器和鉴别器的总参数中显示,但它根本不训练。当以后对随机输入调用相同的编码器模型时,它将始终给出相同的结果0

# Reset Keras' global graph/state so layer names and variables start fresh.
tf.keras.backend.clear_session()

# Shared gamma initializer for every InstanceNormalization layer below.
# NOTE(review): mean=0.02 is unusual for a *gamma* (scale) initializer —
# gammas are conventionally initialized near 1.0; with mean≈0 the norm
# layers start by almost zeroing their outputs. Confirm this is intentional.
gamma_init = tf.keras.initializers.RandomNormal(mean=0.02, stddev=0.02) 
def ENCODER(shape=(int(456/2), int(456/2), 3)):
    """Build the shared image encoder.

    Maps an image of `shape` to a 32-dim L2-normalized code via an
    untrained MobileNetV3Small backbone, global average pooling,
    instance normalization, dropout, LeakyReLU, and a dense projection.

    Returns:
        tf.keras.Model: image -> unit-norm 32-dim encoding.
    """
    image_in = tf.keras.Input(shape=shape, name='encoder')
    backbone = tf.keras.applications.MobileNetV3Small(
        include_top=False, weights=None, input_shape=shape)
    features = backbone(image_in)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    hidden = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(pooled)
    hidden = tf.keras.layers.Dropout(0.4)(hidden)
    hidden = tf.keras.layers.LeakyReLU(alpha=0.3)(hidden)
    projected = tf.keras.layers.Dense(32)(hidden)
    unit_code = tf.keras.layers.Lambda(
        lambda t: tf.math.l2_normalize(t, axis=1),
        name="L2_normalized_encodings")(projected)
    return tf.keras.Model(image_in, unit_code)

def EMBEDDER(shape = (1,),classes = NUM_CLASSES):
    """Build the label-embedding model.

    Args:
        shape: input shape of the integer label tensor (one id per example).
        classes: vocabulary size of the embedding table.

    Returns:
        tf.keras.Model: label id -> 32-dim L2-normalized embedding.
    """
    input_ = tf.keras.Input(shape=shape, name='embedder')
    # BUG FIX: the `classes` parameter was accepted but ignored — the layer
    # hard-coded the global NUM_CLASSES. Use the parameter instead; the
    # default value keeps the original behavior for existing callers.
    embedings_raw = tf.keras.layers.Embedding(classes, 32)(input_)
    # Drop the length-1 sequence axis: (batch, 1, 32) -> (batch, 32).
    embedings = tf.keras.layers.Lambda(lambda x: x[:, 0])(embedings_raw)
    embedings = tf.keras.layers.Lambda(
        lambda x: tf.math.l2_normalize(x, axis=1),
        name="L2_normalized_embeddings")(embedings)
    return tf.keras.Model(input_, embedings)
def DISCRIMANTOR(encoder):
    """Build the discriminator around a shared, reusable `encoder` model.

    The encoder's weights are shared with any other model that uses the same
    instance, so they appear in this model's trainable_variables as well.
    Takes an image and a 32-dim label embedding; outputs a sigmoid
    real/fake probability.
    """
    generation = tf.keras.Input(shape=(int(456/2), int(456/2), 3), name='input_1_disc')
    embedings = tf.keras.Input(shape=(32,), name='input_2_disc')
    # training=True bakes "training mode" into the graph for the nested
    # encoder (dropout active, norm layers in training behavior) even when
    # the outer model is called for inference.
    delta_layer = encoder(generation,training=True)
    ##SIGMOID OUTPUT MODULE
    # Combine the image encoding with the label embedding additively.
    modified_delta = tf.keras.layers.Add()([delta_layer,embedings])
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(modified_delta)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    x = tf.keras.layers.Dense(16)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    x = tf.keras.layers.Dense(8)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    # NOTE(review): this output is already a sigmoid probability, yet
    # DISC_SOFTMAX_LOSS below is built with from_logits=True — that applies
    # a second sigmoid inside the loss. Confirm and fix one of the two.
    output_layer = tf.keras.layers.Dense(1,activation = 'sigmoid')(x)
    return tf.keras.Model(inputs =[generation,embedings],outputs = [output_layer])

def GENERATOR():
    """Build the image generator.

    Takes a 32-dim image encoding and a 32-dim label embedding, concatenates
    them, and upsamples through a stack of Conv2DTranspose blocks (each
    followed by InstanceNorm / Dropout / LeakyReLU) with two skip
    connections, producing a 3-channel image.

    NOTE(review): the intermediate names second_image/third_image/fourth_image
    are reused for different tensors along the way — kept as-is to preserve
    the original wiring exactly.
    """
    encodings = tf.keras.Input(shape=(32,), name='input_1_gen')
    embedings = tf.keras.Input(shape=(32,), name='input_2_gen')
    composed_layer = tf.keras.layers.Concatenate(axis=-1)([encodings,embedings])
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(composed_layer)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    x = tf.keras.layers.Dense(256)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    # Project to a 2x2x128 seed feature map for the deconvolution stack.
    encodings_for_label = tf.keras.layers.Dense(128*2*2)(x)
    image_format = tf.keras.layers.Reshape((2,2, 128), name='de_reshape')(encodings_for_label)
    first_image = tf.keras.layers.Conv2DTranspose(filters = 256,kernel_size=(2, 2) ,strides = 2)(image_format)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(first_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    second_image = tf.keras.layers.Conv2DTranspose(filters = 256,kernel_size=(3, 3),strides = 2)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(second_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    third_image = tf.keras.layers.Conv2DTranspose(filters = 128,kernel_size=(3, 3) ,strides = 2)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(third_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    # Skip connection: upsample first_image and add before a 1x1 deconv.
    fourth_image = tf.keras.layers.Conv2DTranspose(filters = 128,kernel_size=(1, 1))(tf.keras.layers.Add()([x,tf.keras.layers.Conv2DTranspose(filters = 128,kernel_size=(7, 7),strides = 4)(first_image)]))
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(fourth_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    second_image = tf.keras.layers.Conv2DTranspose(filters = 64,kernel_size=(2, 2),strides = 3)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(second_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    third_image = tf.keras.layers.Conv2DTranspose(filters = 64,kernel_size=(2, 2) ,strides = 2)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(third_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)

    # Second skip connection from the earlier fourth_image tensor.
    fourth_image =  tf.keras.layers.Conv2DTranspose(filters = 32,kernel_size=(1, 1))(tf.keras.layers.Add()([x,tf.keras.layers.Conv2DTranspose(filters = 64,kernel_size=(6, 6),strides = 6)(fourth_image)]))
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(fourth_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    second_image = tf.keras.layers.Conv2DTranspose(filters = 16,kernel_size=(1, 1),strides = 2)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(second_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    third_image = tf.keras.layers.Conv2DTranspose(filters = 8,kernel_size=(1, 1))(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(third_image)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    # Final 1x1 deconv down to 3 channels (no output activation).
    fourth_image = tf.keras.layers.Conv2DTranspose(filters = 3,kernel_size = 1,name="fourth_output")(x)
    return tf.keras.Model(inputs=[encodings,embedings],outputs = [fourth_image])

# BUG FIX: in the original page this `def` line was fused onto the end of
# GENERATOR's return statement (a syntax error from copy/paste extraction).
# It has been de-fused onto its own line. NOTE(review): this definition is an
# exact duplicate of the DISCRIMANTOR defined earlier in the file (the page
# repeated it); it is kept verbatim, but one copy could be removed.
def DISCRIMANTOR(encoder):
    """Build the discriminator around a shared `encoder` model; outputs a
    sigmoid real/fake probability for an (image, label-embedding) pair."""
    generation = tf.keras.Input(shape=(int(456/2), int(456/2), 3), name='input_1_disc')
    embedings = tf.keras.Input(shape=(32,), name='input_2_disc')
    delta_layer = encoder(generation,training=True)
    ##SIGMOID OUTPUT MODULE
    modified_delta = tf.keras.layers.Add()([delta_layer,embedings])
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(modified_delta)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    x = tf.keras.layers.Dense(16)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    x = tf.keras.layers.Dense(8)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = tf.keras.layers.Dropout(0.4)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.3)(x)
    output_layer = tf.keras.layers.Dense(1,activation = 'sigmoid')(x)
    return tf.keras.Model(inputs =[generation,embedings],outputs = [output_layer])

对于训练,我使用以下循环:

# Wire the shared encoder into both the discriminator and the generator
# trainer. Because the SAME `ENC` instance is used in both graphs, its
# weights appear in the trainable_variables of BOTH composite models.
input_image =tf.keras.Input(shape=(228,228, 3))
input_label =tf.keras.Input(shape=(1,))
ENC = ENCODER()
encoding = ENC(input_image)
embeding = EMBEDDER()(input_label)
image_new = GENERATOR()([encoding,embeding])
# The discriminator scores the REAL input image against its label embedding.
reality_check = DISCRIMANTOR(ENC)([input_image,embeding])
discriminator = tf.keras.Model(inputs = [input_image,input_label],outputs = [reality_check])   
# gen_trainer exposes both the encoding and the generated image as outputs,
# so gradients of either loss can flow back through ENC and the generator.
gen_trainer = tf.keras.Model(inputs = [input_image,input_label],outputs = [encoding,image_new])


# Single training step: one persistent tape so gradients can be taken twice
# (once for the discriminator, once for the generator trainer).
# NOTE(review): `inputs`, `image`, `labels`, `discrim_labels`, `gen_discim`
# are defined outside this snippet — presumably (image, label) batch tensors
# and real/fake target vectors; confirm against the full training loop.
with tf.GradientTape(persistent=True) as tape:
        encoding_preds,image_ = gen_trainer(inputs,training=True)
        image_loss = GEN_IMAGE_LOSS(image,image_)
        # Re-encode the generated image; ENC is called here WITHOUT
        # training=..., so it runs in its default (inference) mode.
        encoding_loss = ENCODER_LOSS(encoding_preds,ENC(image_))


        # Score real and generated images in one concatenated batch.
        # NOTE(review): DISC_SOFTMAX_LOSS is built with from_logits=True but
        # the discriminator output is already sigmoid — the loss sigmoids the
        # probabilities a second time, which flattens gradients; this is a
        # likely cause of the "encoder never trains / always outputs the same
        # value" symptom described above. Confirm and align the two.
        binary_preds = discriminator([tf.concat([image/1.,image_],axis = 0),tf.concat([labels,labels],axis = 0)],training=True)
        binary_loss = DISC_SOFTMAX_LOSS(discrim_labels,binary_preds)
        loss_final_gen = encoding_loss + 10*image_loss + DISC_SOFTMAX_LOSS(gen_discim,binary_preds)
        loss_final_disc = binary_loss
# NOTE(review): Model.optimizer only exists after compile(); confirm both
# models were compiled with optimizers. Also consider `del tape` afterwards
# since persistent tapes hold resources until released.
gradients = tape.gradient(loss_final_disc, discriminator.trainable_variables)
discriminator.optimizer.apply_gradients(zip(gradients,discriminator.trainable_variables),experimental_aggregate_gradients=False)
gradients = tape.gradient(loss_final_gen, gen_trainer.trainable_variables)
gen_trainer.optimizer.apply_gradients(zip(gradients, gen_trainer.trainable_variables),experimental_aggregate_gradients=False)

具有与here相同的分布式训练功能

损失函数如下所示:

@tf.function
def DISC_EMBED_LOSS(labels, predictions):
    """Per-example triplet semi-hard loss on embeddings, scaled down by the
    number of synchronous replicas for distributed training."""
    triplet = TripletSemiHardLoss(distance_metric = pairwise_distance,
                                  reduction=tf.keras.losses.Reduction.NONE)
    unscaled = triplet(labels, predictions)
    return unscaled / strategy.num_replicas_in_sync
@tf.function
def ENCODER_LOSS(labels,predictions):
    """Cosine-similarity loss shifted into [0, 2] (1 + similarity) and
    averaged over the global batch for distributed training."""
    cosine = tf.keras.losses.CosineSimilarity(
        axis=-1, reduction=tf.keras.losses.Reduction.NONE)
    shifted = 1.+cosine(labels, predictions)
    return tf.nn.compute_average_loss(shifted, global_batch_size=GLOBAL_BATCH_SIZE)
@tf.function
def DISC_SOFTMAX_LOSS(labels, predictions):
    """Binary cross-entropy between real/fake targets and discriminator
    scores, scaled by the number of synchronous replicas.

    BUG FIX: the discriminator's final Dense layer already applies a sigmoid,
    so `predictions` are probabilities, not logits. With from_logits=True the
    loss applied a SECOND sigmoid, squashing every score toward 0.5 and
    starving the model (including the nested encoder) of gradient signal.
    Use from_logits=False to match the model's output.
    """
    per_example_loss =  tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE,from_logits=False)(labels, predictions)
    return per_example_loss/ strategy.num_replicas_in_sync
@tf.function
def GEN_IMAGE_LOSS(labels, predictions):
    """Mean absolute pixel error with both images rescaled from [0, 255] to
    [0, 1], scaled down by the number of synchronous replicas."""
    target = labels/255.
    generated = predictions/255.
    mean_abs_err = tf.math.reduce_mean(tf.math.abs(target - generated))
    return mean_abs_err/ strategy.num_replicas_in_sync

当我在任何其他值上调用编码器时,预测值总是0,与输入或训练轮数无关。我认为这是因为它没有得到训练。可能是什么问题?错误是否存在于梯度带(GradientTape)上?因为当编码器使用 train_on_batch 在批次上自行训练时,编码器的权重是会被更新的。另外,请注意,生成器和鉴别器中没有任何图断开(graph disconnected)错误。

任何帮助都将不胜感激


Tags: imagealphainputinitlayerstfdropoutkeras

热门问题