Why does training a DCGAN with tf.keras fail on TensorFlow v2.1?

I want to train a DCGAN using tensorflow.keras (version 2.1).

When I follow the official tutorial (https://www.tensorflow.org/tutorials/generative/dcgan), the official code trains successfully.

However, when I rewrite it as shown below, training fails.

The generated images look like noise, and the losses stay at almost the same values regardless of the training iteration.

I can't figure out what is causing this.
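
For reference, the loss computation in the linked tutorial looks roughly like this (a sketch of the tutorial's approach; note that both losses are built from the discriminator's logits, not from the generated images themselves):

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(fake_output):
    # fake_output = D(G(z)): the discriminator's logits for the fake images
    return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

My rewritten version is below: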

%tensorflow_version 2.x
import tensorflow as tf

print(tf.__version__)

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import *
from tensorflow.keras.initializers import RandomNormal as RN, Constant
import pickle
import os

# config
class_N = 2
img_height, img_width = 32, 32
channel = 3

# GAN config
Z_dim = 100

# model path
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")



def Generator():
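    # project the Z_dim noise vector to a (2, 2, 128) feature map, then upsample
    # it x16 through four stride-2 Conv2DTranspose layers to a (32, 32, 3) tanh image;
    # note that BatchNormalization comes after each ReLU here, while the tutorial
    # applies it before the activation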
    inputs = Input((Z_dim,))
    in_h = int(img_height / 16)
    in_w = int(img_width / 16)
    base = 128
    # 1/16
    x = Dense(in_h * in_w * base, name='g_dense1',
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(inputs)
    x = Reshape((in_h, in_w, base))(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_dense1_bn')(x)
    # 1/8
    x = Conv2DTranspose(base*4, (5, 5), name='g_conv1', padding='same', strides=(2,2),
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_conv1_bn')(x)
    # 1/4
    x = Conv2DTranspose(base*2, (5, 5), name='g_conv2', padding='same', strides=(2,2),
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_conv2_bn')(x)
    # 1/2
    x = Conv2DTranspose(base, (5, 5), name='g_conv3', padding='same', strides=(2,2),
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = Activation('relu')(x)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5, name='g_conv3_bn')(x)
    # 1/1
    x = Conv2DTranspose(channel, (5, 5), name='g_out', padding='same', strides=(2,2),
        kernel_initializer=RN(mean=0.0, stddev=0.02),  bias_initializer=Constant())(x)
    x = Activation('tanh')(x)
    model = tf.keras.Model(inputs=inputs, outputs=x, name='G')
    return model


def Discriminator():
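    # four stride-2 Conv2D + LeakyReLU blocks downsample (32, 32, 3) to (2, 2, 256),
    # then Flatten + Dense(1) produce a raw logit (no sigmoid, matching from_logits=True)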
    base = 32
    inputs = Input((img_height, img_width, channel))
    x = Conv2D(base, (5, 5), padding='same', strides=(2,2), name='d_conv1',
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(inputs)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(base*2, (5, 5), padding='same', strides=(2,2), name='d_conv2',
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(base*4, (5, 5), padding='same', strides=(2,2), name='d_conv3',
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(base*8, (5, 5), padding='same', strides=(2,2), name='d_conv4',
        kernel_initializer=RN(mean=0.0, stddev=0.02), use_bias=False)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Flatten()(x)
    x = Dense(1, name='d_out',
        kernel_initializer=RN(mean=0.0, stddev=0.02), bias_initializer=Constant())(x)
    model = tf.keras.Model(inputs=inputs, outputs=x, name='D')
    return model

def load_cifar10():
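    # load the raw CIFAR-10 python pickle batches (5 train batches + 1 test batch)
    # and return NHWC arrays with pixel values in [0, 255], labels as int vectors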
    path = 'drive/My Drive/Colab Notebooks/' + 'cifar-10-batches-py'

    if not os.path.exists(path):
        # fall back to downloading the official CIFAR-10 archive into the working directory
        os.system("wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
        os.system("tar xvf cifar-10-python.tar.gz")
        path = 'cifar-10-batches-py'

    # train data
    train_x = np.ndarray([0, 32, 32, 3], dtype=np.float32)
    train_y = np.ndarray([0, ], dtype=int)

    for i in range(1, 6):
        data_path = path + '/data_batch_{}'.format(i)
        with open(data_path, 'rb') as f:
            datas = pickle.load(f, encoding='bytes')
            print(data_path)
            x = datas[b'data']
            x = x.reshape(x.shape[0], 3, 32, 32)
            x = x.transpose(0, 2, 3, 1)
            train_x = np.vstack((train_x, x))

            y = np.array(datas[b'labels'], dtype=int)
            train_y = np.hstack((train_y, y))

    # test data
    data_path = path + '/test_batch'

    with open(data_path, 'rb') as f:
        datas = pickle.load(f, encoding='bytes')
        print(data_path)
        x = datas[b'data']
        x = x.reshape(x.shape[0], 3, 32, 32)
        test_x = x.transpose(0, 2, 3, 1)

        test_y = np.array(datas[b'labels'], dtype=int)

    return train_x, train_y, test_x, test_y


# train
def train():
    # model
    G = Generator()
    D = Discriminator()

    train_x, train_y, test_x, test_y = load_cifar10()
    xs = train_x / 127.5 - 1  # scale pixels to [-1, 1] to match the generator's tanh output

    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
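    # from_logits=True above is required because the discriminator's d_out layer
    # applies no sigmoid and returns raw logits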

    # training
    mb = 64
    mbi = 0
    train_N = len(xs)
    train_ind = np.arange(train_N)
    np.random.seed(0)

    @tf.function
    def train_iter(x, z):
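        # one training step: a single forward pass, then simultaneous G and D updates
        # from two independent GradientTapes (same structure as the tutorial's train_step)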
        with tf.GradientTape() as G_tape, tf.GradientTape() as D_tape:
            # feed forward
            # z -> G -> Gz
            Gz = G(z, training=True)

            # x -> D -> Dx
            # z -> G -> Gz -> D -> DGz
            Dx = D(x, training=True)
            DGz = D(Gz, training=True)

            # get loss
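            # note: loss_G below is computed from the generated images Gz themselves;
            # the tutorial builds the generator loss from the discriminator's output
            # on the fakes (DGz here)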
            loss_G = loss_fn(tf.ones_like(Gz), Gz)
            loss_D_real = loss_fn(tf.ones_like(Dx), Dx)
            loss_D_fake = loss_fn(tf.zeros_like(DGz), DGz)
            loss_D = loss_D_real + loss_D_fake

        # feed back
        gradients_of_G = G_tape.gradient(loss_G, G.trainable_variables)
        gradients_of_D = D_tape.gradient(loss_D, D.trainable_variables)

        # update parameter
        G_optimizer.apply_gradients(zip(gradients_of_G, G.trainable_variables))
        D_optimizer.apply_gradients(zip(gradients_of_D, D.trainable_variables))

        return loss_G, loss_D

    # optimizer
    G_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
    D_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

    checkpoint = tf.train.Checkpoint(G_optimizer=G_optimizer, D_optimizer=D_optimizer, G=G, D=D)

    for ite in range(10000):
        if mbi + mb > train_N:
            mb_ind = train_ind[mbi:]
            np.random.shuffle(train_ind)
            mb_ind = np.hstack((mb_ind, train_ind[:(mb - (train_N - mbi))]))
            mbi = mb - (train_N - mbi)
        else:
            mb_ind = train_ind[mbi: mbi+mb]
            mbi += mb

        x = xs[mb_ind]

        z = np.random.uniform(-1, 1, size=(mb, Z_dim)).astype(np.float32)
        # the tutorial samples z from a standard normal distribution instead:
        # z = tf.random.normal([mb, Z_dim])

        loss_G, loss_D = train_iter(x, z)

        if (ite + 1) % 100 == 0:
            print("iter >>", ite+1, ',G:loss >>', loss_G.numpy(), ',D:loss >>', loss_D.numpy())

        # display generated image
        if (ite + 1) % 1000 == 0:
            Gz = G(z, training=False)  # inference mode so BatchNormalization uses its moving statistics
            _Gz = (Gz * 127.5 + 127.5).numpy().astype(np.uint8)  # map tanh output back to [0, 255]
            for i in range(9):
                plt.subplot(3, 3, i + 1)
                plt.imshow(_Gz[i])
                plt.axis('off')
            plt.show()

    # save model
    checkpoint.save(file_prefix = checkpoint_prefix)
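
# entry point (not shown in the original post): run training when executed as a script
if __name__ == '__main__':
    train()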
