I ran into a problem while trying to build a variational autoencoder (VAE). I am implementing a contrastive VAE similar to https://arxiv.org/pdf/1902.04601.pdf, and I implemented the whole network with TensorFlow and Keras.
The pseudocode in the paper shows the algorithm. My problem: I need to train the encoder and decoder separately from the discriminator D(v), but because I use the discriminator inside the TC (total correlation) loss, the discriminator's weights also get updated while training the VAE. So I am basically looking for a way to exclude the discriminator from the VAE training step and then train it separately.
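For reference, the usual Keras pattern for this (the same one used for GANs) is to set `trainable = False` on the discriminator before compiling the model that contains it: gradients still flow through the frozen layer into the encoder, but its weights are excluded from the trainable variables. A minimal sketch with illustrative names (`enc`, `disc`, `vae_like` are stand-ins, not the actual model):

```python
import tensorflow as tf

# Minimal sketch of the freeze-before-compile pattern. `disc` stands in
# for the logistic-regression discriminator, `enc` for the encoder path.
enc = tf.keras.layers.Dense(4)
disc = tf.keras.layers.Dense(1, activation='sigmoid')
disc.trainable = False  # freeze BEFORE building/compiling the VAE model

x = tf.keras.Input(shape=(8,))
out = disc(enc(x))
vae_like = tf.keras.Model(x, out)
vae_like.compile(optimizer='sgd', loss='binary_crossentropy')
# Gradients still flow THROUGH disc into enc, but disc's weights are
# excluded from vae_like.trainable_variables and are never updated.
```

A second model containing only the discriminator (with `trainable = True` at its compile time) can then be compiled for the separate discriminator step.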
The discriminator is a simple logistic regression classifier:
self.log_reg_classifier = tf.keras.Sequential(
    layers=[tf.keras.layers.Dense(1,
                                  input_shape=(self.z_size * 2,),
                                  activation='sigmoid')])
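For context, the TC penalty typically comes from the density-ratio trick: the classifier's probability D(v) is turned into an estimate of the log density ratio. Assuming that standard form (this helper is not part of the original code), the encoder-side TC term could look like:

```python
import tensorflow as tf

# Hypothetical helper (not from the original code): converts discriminator
# probabilities on v = [z, s] into the density-ratio TC estimate
# E[log D(v) - log(1 - D(v))].
def tc_loss_from_probs(d_v):
    eps = 1e-7  # clip so the logs stay finite
    d_v = tf.clip_by_value(d_v, eps, 1.0 - eps)
    return tf.reduce_mean(tf.math.log(d_v) - tf.math.log(1.0 - d_v))
```

Minimizing this with respect to the encoder is exactly where the unwanted gradient into `log_reg_classifier` comes from if its weights are left trainable.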
The network's forward pass is implemented as:
def __call__(self, inputs, training=True):
    # inputs shape: (100, 2, 64, 64, 3)
    # s := target encoding,     z := background encoding,
    # x := target data,         b := background data
    # Returns the y_pred for each loss head.

    # encode target data with the target encoder
    mean_sx, logvar_sx = self.encode_mu_logvar_target(inputs[:, 0])
    sx = self.sample_encoding(mean_sx, logvar_sx)
    mean_and_logvar_sx = tf.concat([mean_sx, logvar_sx], axis=-1)

    # encode background data with the target encoder
    mean_sb, logvar_sb = self.encode_mu_logvar_target(inputs[:, 1])
    sb = self.sample_encoding(mean_sb, logvar_sb)

    # encode target data with the background encoder
    mean_zx, logvar_zx = self.encode_mu_logvar_backround(inputs[:, 0])
    zx = self.sample_encoding(mean_zx, logvar_zx)
    mean_and_logvar_zx = tf.concat([mean_zx, logvar_zx], axis=-1)

    # encode background data with the background encoder
    mean_zb, logvar_zb = self.encode_mu_logvar_backround(inputs[:, 1])
    zb = self.sample_encoding(mean_zb, logvar_zb)
    mean_and_logvar_zb = tf.concat([mean_zb, logvar_zb], axis=-1)

    # -------- compute Lx and Lb defined in (3) and (4); handled via the losses? --------
    v_target = tf.concat([zx, sx], axis=-1, name='v_target')
    # for the background reconstruction the salient part is zeroed out
    v_backround = tf.concat([zb, tf.zeros_like(sb)], axis=-1, name='v_backround')
    y_x = self.decode(v_target)
    y_b = self.decode(v_backround)
    sx_zx = tf.stack([mean_and_logvar_sx, mean_and_logvar_zx])

    # ----- discriminator input -----
    # v_target = [[zx1,   sx1],      shape: 100 x 100
    #             [zx2,   sx2],
    #              ...
    #             [zx100, sx100]]
    # The discriminator should output 1 for rows drawn from v_target and
    # 0 for rows whose sx/zx pairing has been broken.
    # Shuffle the sx rows against the zx rows to get v_swapped; shuffling
    # whole rows of v_target would leave the pairs intact.
    D_v = self.log_reg_classifier(v_target)
    shuffled_idx = tf.random.shuffle(tf.range(tf.shape(v_target)[0]))
    v_swapped = tf.concat([zx, tf.gather(sx, shuffled_idx)], axis=-1)
    D_swapped = self.log_reg_classifier(v_swapped)
    descri_input = tf.concat([D_v, D_swapped], axis=0)

    return {'reconstruction_target': y_x,
            'kl_target': sx_zx,                  # (2, 100, 100): [sx, zx]
            'reconstruction_backround': y_b,
            'kl_backround': mean_and_logvar_zb,  # should this be zb instead?
            'TC': v_target,
            'discriminator': descri_input}
The network is then trained with:
loss = vae.train_on_batch(x=x_batch,
                          y={'reconstruction_target': x_batch[:, 0],
                             'kl_target': blank_batch,
                             'reconstruction_backround': x_batch[:, 1],
                             'kl_backround': blank_batch,
                             'TC': blank_batch},
                          return_dict=True)
desc_loss = vae.train_on_batch(x=x_batch,
                               y={'discriminator': blank_batch},
                               return_dict=True)
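An alternative to two `train_on_batch` calls is a manual step with `tf.GradientTape` and two optimizers, taking gradients for each loss only with respect to its own variable group. The sketch below uses stand-in loss terms (a plain mean and a binary cross-entropy, not the paper's full objective) and illustrative names:

```python
import tensorflow as tf

vae_opt = tf.keras.optimizers.Adam(1e-2)
disc_opt = tf.keras.optimizers.Adam(1e-2)

def split_train_step(encode_fn, disc, x_batch, vae_vars):
    # VAE step: the loss flows through the discriminator, but gradients
    # are taken only w.r.t. the encoder/decoder weights in `vae_vars`,
    # so the discriminator is never updated here.
    with tf.GradientTape() as tape:
        v = encode_fn(x_batch)
        tc = tf.reduce_mean(disc(v))  # stand-in for the TC term
    vae_opt.apply_gradients(zip(tape.gradient(tc, vae_vars), vae_vars))

    # Discriminator step: only the classifier's own weights move.
    with tf.GradientTape() as tape:
        d_v = disc(encode_fn(x_batch))
        d_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(tf.ones_like(d_v), d_v))
    d_vars = disc.trainable_variables
    disc_opt.apply_gradients(zip(tape.gradient(d_loss, d_vars), d_vars))
```

Because the variable lists are explicit, the discriminator cannot drift during the VAE update even though its output appears in the VAE loss.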