我正在用Keras编写关于vqvae的代码。但是,我遇到了一个奇怪的情况,如果实现没有自定义层的编码器层,它工作得很好。但是,一旦我把相同的代码放在自定义层中,就会得到不同的结果。你知道吗
我已经写了一个自定义的编码器层,它工作得很好。两个自定义层都是在Keras文档的同一个模板中编写的。你知道吗
def residual_stack(h, num_hiddens, num_residual_layers, num_residual_hiddens):
for i in range(num_residual_layers):
h_i = layers.ReLU()(h)
h_i = layers.Conv2D(filters=num_residual_hiddens, kernel_size=(3, 3), strides=(1, 1), padding="same")(h_i)
h_i = layers.ReLU()(h_i)
h_i = layers.Conv2D(filters=num_hiddens, kernel_size=(1, 1), strides=(1, 1), padding="same")(h_i)
h = layers.Add()([h, h_i])
h = layers.ReLU()(h)
return h
class Encoder(layers.Layer):
def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, **kwargs):
self.num_hiddens = num_hiddens
self.num_residual_layers = num_residual_layers
self.num_residual_hiddens = num_residual_hiddens
super(Encoder, self).__init__(**kwargs)
def call(self, x):
h = layers.Conv2D(
filters=int(self.num_hiddens / 2),
kernel_size=(4, 4),
strides=(2, 2),
padding="same")(x)
h = layers.ReLU()(h)
h = layers.Conv2D(
filters=self.num_hiddens,
kernel_size=(4, 4),
strides=(2, 2),
padding="same")(h)
h = layers.ReLU()(h)
h = layers.Conv2D(
filters=self.num_hiddens,
kernel_size=(3, 3),
strides=(1, 1),
padding="same")(h)
h = residual_stack(
h,
self.num_hiddens,
self.num_residual_layers,
self.num_residual_hiddens)
return h
def compute_output_shape(self, input_shape):
space = input_shape[1:-1]
new_space = []
for i in range(len(space)):
new_space.append(int(space[i] // 4))
return (input_shape[0],) + tuple(new_space) + (self.num_hiddens,)
class Decoder(layers.Layer):
def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, **kwargs):
self.num_hiddens = num_hiddens
self.num_residual_layers = num_residual_layers
self.num_residual_hiddens = num_residual_hiddens
super(Decoder, self).__init__(**kwargs)
def call(self, x):
h = layers.Conv2D(
filters=self.num_hiddens,
kernel_size=(3, 3),
strides=(1, 1),
padding="same")(x)
h = residual_stack(
h,
self.num_hiddens,
self.num_residual_layers,
self.num_residual_hiddens)
h = layers.Conv2DTranspose(
filters=int(self.num_hiddens / 2),
kernel_size=(4, 4),
strides=(2, 2),
padding="same")(h)
h = layers.ReLU()(h)
x_recon = layers.Conv2DTranspose(
filters=1,
kernel_size=(4, 4),
strides=(2, 2),
padding="same")(h)
return x_recon
def compute_output_shape(self, input_shape):
space = input_shape[1:-1]
new_space = []
for i in range(len(space)):
new_space.append(int(space[i] * 4))
return (input_shape[0],) + tuple(new_space) + (1,)
ef build_vqvae(x_train, input_shape, embedding_dim=64, num_embeddings=128, commitment_cost=0.25):
'''
VQ-VAE Hyper Parameters.
embedding_dim = 64 # Length of embedding vectors.
num_embeddings = 128 # Number of embedding vectors (high value = high bottleneck capacity).
commitment_cost = 0.25 # Controls the weighting of the loss terms.
'''
num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
assert(x_train.ndim == 4)
# Build modules
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = layers.Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), padding="same", name="to_vq")
input_img = Input(shape=input_shape)
h = encoder(input_img)
# VQVAELayer.
enc = Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), name="pre_vqvae")(h)
enc_inputs = enc
enc, perplexity = VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, name="vqvae")(enc)
x = Lambda(lambda enc: enc_inputs + K.stop_gradient(enc - enc_inputs), name="encoded")(enc)
data_variance = np.var(x_train)
loss = vq_vae_loss_wrapper(data_variance, commitment_cost, enc, enc_inputs)
perplexity = metrices_perplexity_wrapper(perplexity)
x_recon = decoder(x)
vqvae = Model(input_img, x_recon)
50个时代之后,模特什么也没学到:
mean_squared_error: 0.0013 - metrices: 0.0078
9728/13090 [=====================>........] - ETA: 3s - loss: 0.3939 - mean_squared_error: 0.0013 - metrices: 0.0078
9856/13090 [=====================>........] - ETA: 2s - loss: 0.3940 - mean_squared_error: 0.0013 - metrices: 0.0078
9984/13090 [=====================>........] - ETA: 2s - loss: 0.3938 - mean_squared_error: 0.0013 - metrices: 0.0078
10112/13090 [======================>.......] - ETA: 2s - loss: 0.3937 - mean_squared_error: 0.0013 - metrices: 0.0078
10240/13090 [======================>.......] - ETA: 2s - loss: 0.3935 - mean_squared_error: 0.0013 - metrices: 0.0078
10368/13090 [======================>.......] - ETA: 2s - loss: 0.3939 - mean_squared_error: 0.0013 - metrices: 0.0078
10496/13090 [=======================>......] - ETA: 2s - loss: 0.3956 - mean_squared_error: 0.0013 - metrices: 0.0078
10624/13090 [=======================>......] - ETA: 2s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
10752/13090 [=======================>......] - ETA: 2s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
10880/13090 [=======================>......] - ETA: 2s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
11008/13090 [========================>.....] - ETA: 1s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
11136/13090 [========================>.....] - ETA: 1s - loss: 0.3951 - mean_squared_error: 0.0013 - metrices: 0.0078
11264/13090 [========================>.....] - ETA: 1s - loss: 0.3951 - mean_squared_error: 0.0013 - metrices: 0.0078
11392/13090 [=========================>....] - ETA: 1s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
11520/13090 [=========================>....] - ETA: 1s - loss: 0.3956 - mean_squared_error: 0.0013 - metrices: 0.0078
11648/13090 [=========================>....] - ETA: 1s - loss: 0.3957 - mean_squared_error: 0.0013 - metrices: 0.0078
11776/13090 [=========================>....] - ETA: 1s - loss: 0.3957 - mean_squared_error: 0.0013 - metrices: 0.0078
11904/13090 [==========================>...] - ETA: 1s - loss: 0.3973 - mean_squared_error: 0.0013 - metrices: 0.0078
12032/13090 [==========================>...] - ETA: 0s - loss: 0.3970 - mean_squared_error: 0.0013 - metrices: 0.0078
12160/13090 [==========================>...] - ETA: 0s - loss: 0.3978 - mean_squared_error: 0.0013 - metrices: 0.0078
12288/13090 [===========================>..] - ETA: 0s - loss: 0.3978 - mean_squared_error: 0.0013 - metrices: 0.0078
12416/13090 [===========================>..] - ETA: 0s - loss: 0.3973 - mean_squared_error: 0.0013 - metrices: 0.0078
12544/13090 [===========================>..] - ETA: 0s - loss: 0.3972 - mean_squared_error: 0.0013 - metrices: 0.0078
12672/13090 [============================>.] - ETA: 0s - loss: 0.3969 - mean_squared_error: 0.0013 - metrices: 0.0078
12800/13090 [============================>.] - ETA: 0s - loss: 0.3970 - mean_squared_error: 0.0013 - metrices: 0.0078
12928/13090 [============================>.] - ETA: 0s - loss: 0.3967 - mean_squared_error: 0.0013 - metrices: 0.0078
13056/13090 [============================>.] - ETA: 0s - loss: 0.3964 - mean_squared_error: 0.0013 - metrices: 0.0078
13090/13090 [==============================] - 13s 966us/step - loss: 0.3965 - mean_squared_error: 0.0013 - metrices: 0.0078 - val_loss: 0.4074 - val_mean_squared_error: 0.0013 - val_metrices: 0.0078
但是,如果我像这样将代码块放在build\u vae()方法中,效果会很好。这是怎么发生的?你知道吗
def build_vqvae(x_train, input_shape, embedding_dim=64, num_embeddings=128, commitment_cost=0.25):
'''
VQ-VAE Hyper Parameters.
embedding_dim = 64 # Length of embedding vectors.
num_embeddings = 128 # Number of embedding vectors (high value = high bottleneck capacity).
commitment_cost = 0.25 # Controls the weighting of the loss terms.
'''
num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
assert(x_train.ndim == 4)
# Build modules
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = layers.Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), padding="same", name="to_vq")
input_img = Input(shape=input_shape)
h = encoder(input_img)
# VQVAELayer.
enc = Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), name="pre_vqvae")(h)
enc_inputs = enc
enc, perplexity = VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, name="vqvae")(enc)
x = Lambda(lambda enc: enc_inputs + K.stop_gradient(enc - enc_inputs), name="encoded")(enc)
data_variance = np.var(x_train)
loss = vq_vae_loss_wrapper(data_variance, commitment_cost, enc, enc_inputs)
perplexity = metrices_perplexity_wrapper(perplexity)
h = layers.Conv2D(
filters=num_hiddens,
kernel_size=(3, 3),
strides=(1, 1),
padding="same")(x)
h = residual_stack(
h,
num_hiddens,
num_residual_layers,
num_residual_hiddens)
h = layers.Conv2DTranspose(
filters=int(num_hiddens / 2),
kernel_size=(4, 4),
strides=(2, 2),
padding="same")(h)
h = layers.ReLU()(h)
x_recon = layers.Conv2DTranspose(
filters=1,
kernel_size=(4, 4),
strides=(2, 2),
padding="same")(h)
vqvae = Model(input_img, x_recon)
return vqvae, loss, perplexity
结果只适用于少数几个时期的培训:
4224/13090 [========>.....................] - ETA: 10s - loss: 0.0413 - mean_squared_error: 1.4056e-04 - metrices: 0.0078
4352/13090 [========>.....................] - ETA: 10s - loss: 0.0412 - mean_squared_error: 1.4039e-04 - metrices: 0.0078
4480/13090 [=========>....................] - ETA: 10s - loss: 0.0413 - mean_squared_error: 1.4049e-04 - metrices: 0.0078
4608/13090 [=========>....................] - ETA: 9s - loss: 0.0411 - mean_squared_error: 1.3994e-04 - metrices: 0.0078
4736/13090 [=========>....................] - ETA: 9s - loss: 0.0409 - mean_squared_error: 1.3917e-04 - metrices: 0.0078
4864/13090 [==========>...................] - ETA: 9s - loss: 0.0407 - mean_squared_error: 1.3861e-04 - metrices: 0.0078
4992/13090 [==========>...................] - ETA: 9s - loss: 0.0406 - mean_squared_error: 1.3821e-04 - metrices: 0.0078
5120/13090 [==========>...................] - ETA: 9s - loss: 0.0405 - mean_squared_error: 1.3785e-04 - metrices: 0.0078
5248/13090 [===========>..................] - ETA: 9s - loss: 0.0403 - mean_squared_error: 1.3733e-04 - metrices: 0.0078
5376/13090 [===========>..................] - ETA: 8s - loss: 0.0401 - mean_squared_error: 1.3672e-04 - metrices: 0.0078
5504/13090 [===========>..................] - ETA: 8s - loss: 0.0401 - mean_squared_error: 1.3657e-04 - metrices: 0.0078
5632/13090 [===========>..................] - ETA: 8s - loss: 0.0401 - mean_squared_error: 1.3655e-04 - metrices: 0.0078
5760/13090 [============>.................] - ETA: 8s - loss: 0.0400 - mean_squared_error: 1.3619e-04 - metrices: 0.0078
5888/13090 [============>.................] - ETA: 8s - loss: 0.0400 - mean_squared_error: 1.3606e-04 - metrices: 0.0078
6016/13090 [============>.................] - ETA: 8s - loss: 0.0399 - mean_squared_error: 1.3579e-04 - metrices: 0.0078
6144/13090 [=============>................] - ETA: 8s - loss: 0.0398 - mean_squared_error: 1.3544e-04 - metrices: 0.0078
6272/13090 [=============>................] - ETA: 7s - loss: 0.0397 - mean_squared_error: 1.3532e-04 - metrices: 0.0078
6400/13090 [=============>................] - ETA: 7s - loss: 0.0397 - mean_squared_error: 1.3505e-04 - metrices: 0.0078
6528/13090 [=============>................] - ETA: 7s - loss: 0.0403 - mean_squared_error: 1.3727e-04 - metrices: 0.0078
6656/13090 [==============>...............] - ETA: 7s - loss: 0.0401 - mean_squared_error: 1.3673e-04 - metrices: 0.0078
6784/13090 [==============>...............] - ETA: 7s - loss: 0.0401 - mean_squared_error: 1.3641e-04 - metrices: 0.0078
6912/13090 [==============>...............] - ETA: 7s - loss: 0.0403 - mean_squared_error: 1.3722e-04 - metrices: 0.0078
7040/13090 [===============>..............] - ETA: 7s - loss: 0.0402 - mean_squared_error: 1.3697e-04 - metrices: 0.0078
7168/13090 [===============>..............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3670e-04 - metrices: 0.0078
7296/13090 [===============>..............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3665e-04 - metrices: 0.0078
7424/13090 [================>.............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3666e-04 - metrices: 0.0078
7552/13090 [================>.............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3655e-04 - metrices: 0.0078
目前没有回答
相关问题 更多 >
编程相关推荐