Problem with Keras custom layers when implementing a VQ-VAE

Posted on 2024-04-28 19:33:55


I am writing VQ-VAE code in Keras and have run into a strange situation: if I implement the encoder layers without a custom layer, everything works fine, but as soon as I move the same code into a custom layer I get different results. Any ideas?

I have already written a custom encoder layer, and it runs without errors. Both custom layers are written following the same template from the Keras documentation.
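For reference, the template from the Keras documentation that both layers follow looks roughly like this (a minimal sketch with placeholder names; it is not part of my model):

from keras import backend as K
from keras.layers import Layer

class MyLayer(Layer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Trainable weights are created once, when the layer first sees its input shape.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)

    def call(self, x):
        # The forward pass only applies the weights created in build().
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

In that template the trainable weights are created once in build(), and call() only applies them.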

import numpy as np
from keras import backend as K
from keras import layers
from keras.layers import Input, Conv2D, Lambda
from keras.models import Model

# VectorQuantizer, vq_vae_loss_wrapper and metrices_perplexity_wrapper are my own
# helpers defined elsewhere (not shown here).


def residual_stack(h, num_hiddens, num_residual_layers, num_residual_hiddens):
    for i in range(num_residual_layers):
        h_i = layers.ReLU()(h)
        h_i = layers.Conv2D(filters=num_residual_hiddens, kernel_size=(3, 3), strides=(1, 1), padding="same")(h_i)
        h_i = layers.ReLU()(h_i)
        h_i = layers.Conv2D(filters=num_hiddens, kernel_size=(1, 1), strides=(1, 1), padding="same")(h_i)
        h = layers.Add()([h, h_i])
        h = layers.ReLU()(h)
    return h


class Encoder(layers.Layer):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, **kwargs):
        self.num_hiddens = num_hiddens
        self.num_residual_layers = num_residual_layers
        self.num_residual_hiddens = num_residual_hiddens
        super(Encoder, self).__init__(**kwargs)

    def call(self, x):
        h = layers.Conv2D(
            filters=int(self.num_hiddens / 2),
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same")(x)
        h = layers.ReLU()(h)

        h = layers.Conv2D(
            filters=self.num_hiddens,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same")(h)
        h = layers.ReLU()(h)

        h = layers.Conv2D(
            filters=self.num_hiddens,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="same")(h)

        h = residual_stack(
            h,
            self.num_hiddens,
            self.num_residual_layers,
            self.num_residual_hiddens)
        return h

    def compute_output_shape(self, input_shape):
        space = input_shape[1:-1]
        new_space = []
        for i in range(len(space)):
            new_space.append(int(space[i] // 4))
        return (input_shape[0],) + tuple(new_space) + (self.num_hiddens,)


class Decoder(layers.Layer):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, **kwargs):
        self.num_hiddens = num_hiddens
        self.num_residual_layers = num_residual_layers
        self.num_residual_hiddens = num_residual_hiddens
        super(Decoder, self).__init__(**kwargs)

    def call(self, x):
        h = layers.Conv2D(
            filters=self.num_hiddens,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="same")(x)

        h = residual_stack(
            h,
            self.num_hiddens,
            self.num_residual_layers,
            self.num_residual_hiddens)

        h = layers.Conv2DTranspose(
            filters=int(self.num_hiddens / 2),
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same")(h)
        h = layers.ReLU()(h)

        x_recon = layers.Conv2DTranspose(
            filters=1,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same")(h)

        return x_recon

    def compute_output_shape(self, input_shape):
        space = input_shape[1:-1]
        new_space = []
        for i in range(len(space)):
            new_space.append(int(space[i] * 4))
        return (input_shape[0],) + tuple(new_space) + (1,)

def build_vqvae(x_train, input_shape, embedding_dim=64, num_embeddings=128, commitment_cost=0.25):
    '''
    VQ-VAE Hyper Parameters.
    embedding_dim = 64 # Length of embedding vectors.
    num_embeddings = 128 # Number of embedding vectors (high value = high bottleneck capacity).
    commitment_cost = 0.25 # Controls the weighting of the loss terms.
    '''
    num_hiddens = 128
    num_residual_hiddens = 32
    num_residual_layers = 2
    assert(x_train.ndim == 4)
    # Build modules
    encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    pre_vq_conv1 = layers.Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), padding="same", name="to_vq")
    input_img = Input(shape=input_shape)
    h = encoder(input_img)

    # VQVAELayer.
    enc = Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), name="pre_vqvae")(h)
    enc_inputs = enc
    enc, perplexity = VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, name="vqvae")(enc)
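    # Straight-through estimator: the forward pass outputs the quantized codes,
    # while gradients flow back to enc_inputs as if the quantization were the identity.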
    x = Lambda(lambda enc: enc_inputs + K.stop_gradient(enc - enc_inputs), name="encoded")(enc)
    data_variance = np.var(x_train)
    loss = vq_vae_loss_wrapper(data_variance, commitment_cost, enc, enc_inputs)
    perplexity = metrices_perplexity_wrapper(perplexity)

    x_recon = decoder(x)
    vqvae = Model(input_img, x_recon)

    return vqvae, loss, perplexity

After 50 epochs the model has learned nothing:

 9728/13090 [=====================>........] - ETA: 3s - loss: 0.3939 - mean_squared_error: 0.0013 - metrices: 0.0078
 9856/13090 [=====================>........] - ETA: 2s - loss: 0.3940 - mean_squared_error: 0.0013 - metrices: 0.0078
 9984/13090 [=====================>........] - ETA: 2s - loss: 0.3938 - mean_squared_error: 0.0013 - metrices: 0.0078
10112/13090 [======================>.......] - ETA: 2s - loss: 0.3937 - mean_squared_error: 0.0013 - metrices: 0.0078
10240/13090 [======================>.......] - ETA: 2s - loss: 0.3935 - mean_squared_error: 0.0013 - metrices: 0.0078
10368/13090 [======================>.......] - ETA: 2s - loss: 0.3939 - mean_squared_error: 0.0013 - metrices: 0.0078
10496/13090 [=======================>......] - ETA: 2s - loss: 0.3956 - mean_squared_error: 0.0013 - metrices: 0.0078
10624/13090 [=======================>......] - ETA: 2s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
10752/13090 [=======================>......] - ETA: 2s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
10880/13090 [=======================>......] - ETA: 2s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
11008/13090 [========================>.....] - ETA: 1s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
11136/13090 [========================>.....] - ETA: 1s - loss: 0.3951 - mean_squared_error: 0.0013 - metrices: 0.0078
11264/13090 [========================>.....] - ETA: 1s - loss: 0.3951 - mean_squared_error: 0.0013 - metrices: 0.0078
11392/13090 [=========================>....] - ETA: 1s - loss: 0.3950 - mean_squared_error: 0.0013 - metrices: 0.0078
11520/13090 [=========================>....] - ETA: 1s - loss: 0.3956 - mean_squared_error: 0.0013 - metrices: 0.0078
11648/13090 [=========================>....] - ETA: 1s - loss: 0.3957 - mean_squared_error: 0.0013 - metrices: 0.0078
11776/13090 [=========================>....] - ETA: 1s - loss: 0.3957 - mean_squared_error: 0.0013 - metrices: 0.0078
11904/13090 [==========================>...] - ETA: 1s - loss: 0.3973 - mean_squared_error: 0.0013 - metrices: 0.0078
12032/13090 [==========================>...] - ETA: 0s - loss: 0.3970 - mean_squared_error: 0.0013 - metrices: 0.0078
12160/13090 [==========================>...] - ETA: 0s - loss: 0.3978 - mean_squared_error: 0.0013 - metrices: 0.0078
12288/13090 [===========================>..] - ETA: 0s - loss: 0.3978 - mean_squared_error: 0.0013 - metrices: 0.0078
12416/13090 [===========================>..] - ETA: 0s - loss: 0.3973 - mean_squared_error: 0.0013 - metrices: 0.0078
12544/13090 [===========================>..] - ETA: 0s - loss: 0.3972 - mean_squared_error: 0.0013 - metrices: 0.0078
12672/13090 [============================>.] - ETA: 0s - loss: 0.3969 - mean_squared_error: 0.0013 - metrices: 0.0078
12800/13090 [============================>.] - ETA: 0s - loss: 0.3970 - mean_squared_error: 0.0013 - metrices: 0.0078
12928/13090 [============================>.] - ETA: 0s - loss: 0.3967 - mean_squared_error: 0.0013 - metrices: 0.0078
13056/13090 [============================>.] - ETA: 0s - loss: 0.3964 - mean_squared_error: 0.0013 - metrices: 0.0078
13090/13090 [==============================] - 13s 966us/step - loss: 0.3965 - mean_squared_error: 0.0013 - metrices: 0.0078 - val_loss: 0.4074 - val_mean_squared_error: 0.0013 - val_metrices: 0.0078

However, if I put the same code block directly inside the build_vqvae() method like this, it works fine. How can that happen? Any ideas?

def build_vqvae(x_train, input_shape, embedding_dim=64, num_embeddings=128, commitment_cost=0.25):
    '''
    VQ-VAE Hyper Parameters.
    embedding_dim = 64 # Length of embedding vectors.
    num_embeddings = 128 # Number of embedding vectors (high value = high bottleneck capacity).
    commitment_cost = 0.25 # Controls the weighting of the loss terms.
    '''
    num_hiddens = 128
    num_residual_hiddens = 32
    num_residual_layers = 2
    assert(x_train.ndim == 4)
    # Build modules
    encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    pre_vq_conv1 = layers.Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), padding="same", name="to_vq")
    input_img = Input(shape=input_shape)
    h = encoder(input_img)

    # VQVAELayer.
    enc = Conv2D(embedding_dim, kernel_size=(1, 1), strides=(1, 1), name="pre_vqvae")(h)
    enc_inputs = enc
    enc, perplexity = VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, name="vqvae")(enc)
    x = Lambda(lambda enc: enc_inputs + K.stop_gradient(enc - enc_inputs), name="encoded")(enc)
    data_variance = np.var(x_train)
    loss = vq_vae_loss_wrapper(data_variance, commitment_cost, enc, enc_inputs)
    perplexity = metrices_perplexity_wrapper(perplexity)

    h = layers.Conv2D(
        filters=num_hiddens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="same")(x)

    h = residual_stack(
        h,
        num_hiddens,
        num_residual_layers,
        num_residual_hiddens)

    h = layers.Conv2DTranspose(
        filters=int(num_hiddens / 2),
        kernel_size=(4, 4),
        strides=(2, 2),
        padding="same")(h)
    h = layers.ReLU()(h)

    x_recon = layers.Conv2DTranspose(
        filters=1,
        kernel_size=(4, 4),
        strides=(2, 2),
        padding="same")(h)
    vqvae = Model(input_img, x_recon)

    return vqvae, loss, perplexity
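For completeness, the returned model gets compiled and trained roughly like this (a sketch only; the optimizer, batch size and validation split shown here are placeholders, not necessarily the exact values I used):

# Hypothetical training setup: optimizer, batch size and validation split are placeholders.
vqvae, loss, perplexity = build_vqvae(x_train, input_shape=x_train.shape[1:])
vqvae.compile(optimizer='adam',
              loss=loss,                                   # custom VQ-VAE loss from vq_vae_loss_wrapper
              metrics=['mean_squared_error', perplexity])  # perplexity metric from metrices_perplexity_wrapper
vqvae.fit(x_train, x_train,
          batch_size=128,
          epochs=50,
          validation_split=0.1)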

These are the results after training for only a few epochs:

 4224/13090 [========>.....................] - ETA: 10s - loss: 0.0413 - mean_squared_error: 1.4056e-04 - metrices: 0.0078
 4352/13090 [========>.....................] - ETA: 10s - loss: 0.0412 - mean_squared_error: 1.4039e-04 - metrices: 0.0078
 4480/13090 [=========>....................] - ETA: 10s - loss: 0.0413 - mean_squared_error: 1.4049e-04 - metrices: 0.0078
 4608/13090 [=========>....................] - ETA: 9s - loss: 0.0411 - mean_squared_error: 1.3994e-04 - metrices: 0.0078 
 4736/13090 [=========>....................] - ETA: 9s - loss: 0.0409 - mean_squared_error: 1.3917e-04 - metrices: 0.0078
 4864/13090 [==========>...................] - ETA: 9s - loss: 0.0407 - mean_squared_error: 1.3861e-04 - metrices: 0.0078
 4992/13090 [==========>...................] - ETA: 9s - loss: 0.0406 - mean_squared_error: 1.3821e-04 - metrices: 0.0078
 5120/13090 [==========>...................] - ETA: 9s - loss: 0.0405 - mean_squared_error: 1.3785e-04 - metrices: 0.0078
 5248/13090 [===========>..................] - ETA: 9s - loss: 0.0403 - mean_squared_error: 1.3733e-04 - metrices: 0.0078
 5376/13090 [===========>..................] - ETA: 8s - loss: 0.0401 - mean_squared_error: 1.3672e-04 - metrices: 0.0078
 5504/13090 [===========>..................] - ETA: 8s - loss: 0.0401 - mean_squared_error: 1.3657e-04 - metrices: 0.0078
 5632/13090 [===========>..................] - ETA: 8s - loss: 0.0401 - mean_squared_error: 1.3655e-04 - metrices: 0.0078
 5760/13090 [============>.................] - ETA: 8s - loss: 0.0400 - mean_squared_error: 1.3619e-04 - metrices: 0.0078
 5888/13090 [============>.................] - ETA: 8s - loss: 0.0400 - mean_squared_error: 1.3606e-04 - metrices: 0.0078
 6016/13090 [============>.................] - ETA: 8s - loss: 0.0399 - mean_squared_error: 1.3579e-04 - metrices: 0.0078
 6144/13090 [=============>................] - ETA: 8s - loss: 0.0398 - mean_squared_error: 1.3544e-04 - metrices: 0.0078
 6272/13090 [=============>................] - ETA: 7s - loss: 0.0397 - mean_squared_error: 1.3532e-04 - metrices: 0.0078
 6400/13090 [=============>................] - ETA: 7s - loss: 0.0397 - mean_squared_error: 1.3505e-04 - metrices: 0.0078
 6528/13090 [=============>................] - ETA: 7s - loss: 0.0403 - mean_squared_error: 1.3727e-04 - metrices: 0.0078
 6656/13090 [==============>...............] - ETA: 7s - loss: 0.0401 - mean_squared_error: 1.3673e-04 - metrices: 0.0078
 6784/13090 [==============>...............] - ETA: 7s - loss: 0.0401 - mean_squared_error: 1.3641e-04 - metrices: 0.0078
 6912/13090 [==============>...............] - ETA: 7s - loss: 0.0403 - mean_squared_error: 1.3722e-04 - metrices: 0.0078
 7040/13090 [===============>..............] - ETA: 7s - loss: 0.0402 - mean_squared_error: 1.3697e-04 - metrices: 0.0078
 7168/13090 [===============>..............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3670e-04 - metrices: 0.0078
 7296/13090 [===============>..............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3665e-04 - metrices: 0.0078
 7424/13090 [================>.............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3666e-04 - metrices: 0.0078
 7552/13090 [================>.............] - ETA: 6s - loss: 0.0401 - mean_squared_error: 1.3655e-04 - metrices: 0.0078
