How do I train different parts of a Keras model independently?


I ran into a problem while building a variational autoencoder (VAE). I am trying to implement a contrastive VAE along the lines of https://arxiv.org/pdf/1902.04601.pdf, and the whole network is implemented in TensorFlow with Keras.

Pseudocode (figure from the paper, not reproduced here)

The pseudocode shows the training algorithm. My problem is that the encoder and decoder have to be trained separately from the discriminator D(v), but because the discriminator appears in the TC (total correlation) loss, its weights also get updated while the VAE is being trained. So I am essentially looking for a way to exclude the discriminator from the VAE training step and then train it on its own.
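For reference, the usual way to get this behavior with compiled Keras models is the GAN-style two-model pattern: compile the discriminator on its own, then set its trainable flag to False before building and compiling the combined model. With tf.keras in TF2, the trainable flags are snapshotted at compile time, so each compiled model only ever updates its own set of weights. A minimal, self-contained sketch of just the mechanism (toy layers, not the actual cVAE):

    import tensorflow as tf

    # toy stand-ins for the discriminator and the encoder/decoder part
    disc = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
    disc.compile(optimizer='adam', loss='binary_crossentropy')  # D is trainable here

    gen = tf.keras.layers.Dense(4)

    disc.trainable = False  # freeze D for the combined graph
    inp = tf.keras.Input(shape=(4,))
    combined = tf.keras.Model(inp, disc(gen(inp)))
    combined.compile(optimizer='adam', loss='binary_crossentropy')

    x = tf.random.normal((8, 4))
    combined.train_on_batch(x, tf.ones((8, 1)))  # updates gen only
    disc.train_on_batch(x, tf.zeros((8, 1)))     # updates disc only

In the actual model, disc would play the role of the log_reg_classifier below and gen the encoder/decoder stack.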

The discriminator is a simple logistic-regression classifier:

self.log_reg_classifier = tf.keras.Sequential(
    [tf.keras.layers.Dense(1,
                           input_shape=(self.z_size * 2,),
                           activation='sigmoid')])
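For context, the classifier operates on the concatenation of the two latent codes, so its input width is self.z_size * 2. A quick stand-alone shape check (z_size = 50 is an assumed value for illustration, matching the 100 x 100 shape noted in the comments below):

    import tensorflow as tf

    z_size = 50  # assumed for illustration
    clf = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(z_size * 2,),
                               activation='sigmoid')])

    z = tf.random.normal((100, z_size))  # background code batch
    s = tf.random.normal((100, z_size))  # target code batch
    v = tf.concat([z, s], axis=-1)       # shape (100, 100)
    print(clf(v).shape)                  # (100, 1) probabilities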

The network is built as follows:

def __call__(self, inputs, training=True):
    # inputs shape: (100, 2, 64, 64, 3)
    # not sure why Keras forces us to use the training flag
    # s := target_encoder,
    # z := backround_encoder,
    # x := target_data,
    # b := backround_data
    # returns the y_pred dict consumed by the losses

    # compute target encoding with the target encoder
    mean_sx, logvar_sx = self.encode_mu_logvar_target(inputs[:, 0])
    sx = self.sample_encoding(mean_sx, logvar_sx)
    mean_and_logvar_sx = tf.concat([mean_sx, logvar_sx], axis=-1)

    # compute background encoding with the target encoder
    mean_sb, logvar_sb = self.encode_mu_logvar_target(inputs[:, 1])
    sb = self.sample_encoding(mean_sb, logvar_sb)
    #mean_and_logvar_sb = tf.concat([mean_sb, logvar_sb], axis=-1)

    # compute target encoding with the background encoder
    mean_zx, logvar_zx = self.encode_mu_logvar_backround(inputs[:, 0])
    zx = self.sample_encoding(mean_zx, logvar_zx)
    mean_and_logvar_zx = tf.concat([mean_zx, logvar_zx], axis=-1)

    # compute background encoding with the background encoder
    mean_zb, logvar_zb = self.encode_mu_logvar_backround(inputs[:, 1])
    zb = self.sample_encoding(mean_zb, logvar_zb)
    mean_and_logvar_zb = tf.concat([mean_zb, logvar_zb], axis=-1)

    # -------- compute Lx and Lb defined in (3) and (4) -------- (handled via the loss?)

    v_target = tf.concat([zx, sx], axis=-1, name='v_target')
    # for the background samples, the salient part of v is zeroed out
    v_backround = tf.concat([zb, tf.zeros_like(sb)],
                            axis=-1, name='v_backround')
    
    y_x = self.decode(v_target)
    y_b = self.decode(v_backround)
          
    sx_zx = tf.stack([mean_and_logvar_sx, mean_and_logvar_zx])

    # ----- train discriminator -----
    # v_target = [zx1,   sx1 ]   size: 100 x 100
    #            [zx2,   sx2 ]
    #              ..     ..
    #            [zx100, sx100]
    # the discriminator should classify where each example comes from:
    # 1 if it is a row of v_target, 0 otherwise

    D_v = self.log_reg_classifier(v_target)

    # shuffle the sx rows independently of the zx rows and concatenate,
    # so I get the broken-up pairs v_swapped (shuffling whole rows of
    # v_target would keep each (zx_i, sx_i) pair intact)
    perm = tf.random.shuffle(tf.range(tf.shape(v_target)[0]))
    v_swapped = tf.concat([zx, tf.gather(sx, perm)], axis=-1)

    D_swapped = self.log_reg_classifier(v_swapped)
    descri_input = tf.concat([D_v, D_swapped], axis=0)
    
    return {'reconstruction_target': y_x,
            'kl_target': sx_zx,  # shape (2, 100, 100): [sx, zx]
            'reconstruction_backround': y_b,
            'kl_backround': mean_and_logvar_zb,  # shouldn't this be zb?
            'TC': v_target,
            'discriminator': descri_input}
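As an aside, if compiled losses turn out to be too inflexible here, a subclassed model like this can also be trained with a custom step: two optimizers, each applied to a disjoint variable list. Gradients still flow through the discriminator's forward pass in the TC term, but its weights only receive updates in the second step. A sketch under assumptions: compute_vae_loss is a hypothetical placeholder for the full reconstruction + KL + TC loss, and the variable lists are split off the model's sub-layers:

    import tensorflow as tf

    vae_opt = tf.keras.optimizers.Adam(1e-3)
    disc_opt = tf.keras.optimizers.Adam(1e-3)

    def train_step(model, x_batch):
        # split variables: discriminator vs. everything else (compare by id,
        # since == on tf.Variable is elementwise)
        disc_vars = model.log_reg_classifier.trainable_variables
        disc_ids = {id(v) for v in disc_vars}
        vae_vars = [v for v in model.trainable_variables if id(v) not in disc_ids]

        # 1) VAE step: gradients pass through D's forward computation (TC term)
        #    but are applied only to the encoder/decoder weights
        with tf.GradientTape() as tape:
            outputs = model(x_batch, training=True)
            vae_loss = compute_vae_loss(outputs, x_batch)  # hypothetical helper
        grads = tape.gradient(vae_loss, vae_vars)
        vae_opt.apply_gradients(zip(grads, vae_vars))

        # 2) discriminator step: binary cross-entropy on [D(v); D(v_swapped)],
        #    applied only to the discriminator weights
        with tf.GradientTape() as tape:
            outputs = model(x_batch, training=True)
            d = outputs['discriminator']            # shape (2 * batch, 1)
            n = tf.shape(d)[0] // 2
            labels = tf.concat([tf.ones((n, 1)), tf.zeros((n, 1))], axis=0)
            disc_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(labels, d))
        grads = tape.gradient(disc_loss, disc_vars)
        disc_opt.apply_gradients(zip(grads, disc_vars))
        return vae_loss, disc_loss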

At the moment the network is trained like this:

loss = vae.train_on_batch(x=x_batch,
                          y={'reconstruction_target': x_batch[:, 0],
                             'kl_target': blank_batch,
                             'reconstruction_backround': x_batch[:, 1],
                             'kl_backround': blank_batch,
                             'TC': blank_batch},
                          return_dict=True)

desc_loss = vae.train_on_batch(x=x_batch,
                               y={'discriminator': blank_batch},
                               return_dict=True)
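One more detail: the 'discriminator' output is a stack of sigmoid probabilities for [D(v); D(v_swapped)], so if the compiled loss for that head is plain binary cross-entropy, the target needs explicit ones and zeros rather than a blank placeholder. A sketch (assuming the label batch matches x_batch's leading dimension):

    import numpy as np

    batch_size = x_batch.shape[0]
    disc_labels = np.concatenate(
        [np.ones((batch_size, 1), dtype=np.float32),    # first half: real pairs v
         np.zeros((batch_size, 1), dtype=np.float32)],  # second half: shuffled pairs
        axis=0)

    desc_loss = vae.train_on_batch(x=x_batch,
                                   y={'discriminator': disc_labels},
                                   return_dict=True)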
        
