diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2dbf19b027bbb902449bc0efef206c12485911a
--- /dev/null
+++ b/main.py
@@ -0,0 +1,161 @@
+from argparse import ArgumentParser
+from dataloader import DataLoader
+from model import FastSRGAN
+import tensorflow as tf
+import os
+
+parser = ArgumentParser()
+parser.add_argument('--image_dir', type=str, help='Path to high resolution image directory.')
+parser.add_argument('--batch_size', default=8, type=int, help='Batch size for training.')
+parser.add_argument('--epochs', default=1, type=int, help='Number of epochs for training.')
+parser.add_argument('--hr_size', default=384, type=int, help='High resolution input size.')
+parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate for optimizers.')
+parser.add_argument('--save_iter', default=200, type=int,
+                    help='Number of iterations between saving tensorboard summaries and models.')
+
+
+@tf.function
+def pretrain_step(model, x, y):
+    """
+    A single step of generator pre-training.
+    Args:
+        model: A model object with a tf keras compiled generator.
+        x: The low resolution image tensor.
+        y: The high resolution image tensor.
+    Returns:
+        loss_mse: The pixel-wise MSE between the generated and real high res images.
+    """
+    with tf.GradientTape() as tape:
+        fake_hr = model.generator(x)
+        loss_mse = tf.keras.losses.MeanSquaredError()(y, fake_hr)
+
+    grads = tape.gradient(loss_mse, model.generator.trainable_variables)
+    model.gen_optimizer.apply_gradients(zip(grads, model.generator.trainable_variables))
+
+    return loss_mse
+
+
+def pretrain_generator(model, dataset, writer):
+    """Pretrains the generator on pixel-wise MSE for one pass over the dataset,
+    to avoid poor local minima when adversarial training starts.
+    Args:
+        model: The keras model to train.
+        dataset: A tf dataset object of low and high res images to pretrain over.
+        writer: A summary writer object.
+    Returns:
+        None
+    """
+    with writer.as_default():
+        iteration = 0
+        for x, y in dataset:
+            loss = pretrain_step(model, x, y)
+            if iteration % 20 == 0:
+                tf.summary.scalar('MSE Loss', loss, step=tf.cast(iteration, tf.int64))
+                writer.flush()
+            iteration += 1
+
+
+@tf.function
+def train_step(model, x, y):
+    """A single train step for the SRGAN.
+    Args:
+        model: An object that contains tf keras compiled generator and
+            discriminator models.
+        x: The low resolution input image.
+        y: The desired high resolution output image.
+
+    Returns:
+        d_loss: The total discriminator loss (real + fake terms).
+        adv_loss: The adversarial component of the generator loss.
+        content_loss: The feature-space (content) component of the generator loss.
+        mse_loss: The pixel-wise MSE component of the generator loss.
+    """
+    # Adversarial ground truths: all ones for real patches, all zeros for generated ones.
+    valid = tf.ones((x.shape[0],) + model.disc_patch)
+    fake = tf.zeros((x.shape[0],) + model.disc_patch)
+
+    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
+        # Generate a high resolution image from the low resolution input.
+        fake_hr = model.generator(x)
+
+        # Discriminator predictions (original images = real / generated = fake).
+        valid_prediction = model.discriminator(y)
+        fake_prediction = model.discriminator(fake_hr)
+
+        # Generator loss
+        content_loss = model.content_loss(y, fake_hr)
+        adv_loss = 1e-3 * tf.keras.losses.BinaryCrossentropy()(valid, fake_prediction)
+        mse_loss = tf.keras.losses.MeanSquaredError()(y, fake_hr)
+        perceptual_loss = content_loss + adv_loss + mse_loss
+
+        # Discriminator loss
+        valid_loss = tf.keras.losses.BinaryCrossentropy()(valid, valid_prediction)
+        fake_loss = tf.keras.losses.BinaryCrossentropy()(fake, fake_prediction)
+        d_loss = valid_loss + fake_loss
+
+    # Backprop on the generator.
+    gen_grads = gen_tape.gradient(perceptual_loss, model.generator.trainable_variables)
+    model.gen_optimizer.apply_gradients(zip(gen_grads, model.generator.trainable_variables))
+
+    # Backprop on the discriminator.
+    disc_grads = disc_tape.gradient(d_loss, model.discriminator.trainable_variables)
+    model.disc_optimizer.apply_gradients(zip(disc_grads, model.discriminator.trainable_variables))
+
+    return d_loss, adv_loss, content_loss, mse_loss
+
+
+def train(model, dataset, log_iter, writer):
+    """
+    Runs one epoch of adversarial training for the SR-GAN.
+    Args:
+        model: An object that contains tf keras compiled generator and
+            discriminator models.
+        dataset: A tf data object that contains low and high res images.
+        log_iter: Number of iterations after which to add logs in
+            tensorboard.
+        writer: Summary writer
+    """
+    with writer.as_default():
+        # Iterate over the dataset.
+        for x, y in dataset:
+            disc_loss, adv_loss, content_loss, mse_loss = train_step(model, x, y)
+            # Log tensorboard summaries and checkpoint the models every log_iter iterations.
+            if model.iterations % log_iter == 0:
+                tf.summary.scalar('Adversarial Loss', adv_loss, step=model.iterations)
+                tf.summary.scalar('Content Loss', content_loss, step=model.iterations)
+                tf.summary.scalar('MSE Loss', mse_loss, step=model.iterations)
+                tf.summary.scalar('Discriminator Loss', disc_loss, step=model.iterations)
+                tf.summary.image('Low Res', tf.cast(255 * x, tf.uint8), step=model.iterations)
+                tf.summary.image('High Res', tf.cast(255 * (y + 1.0) / 2.0, tf.uint8), step=model.iterations)
+                tf.summary.image('Generated', tf.cast(255 * (model.generator.predict(x) + 1.0) / 2.0, tf.uint8),
+                                 step=model.iterations)
+                model.generator.save('models/generator.h5')
+                model.discriminator.save('models/discriminator.h5')
+                writer.flush()
+            model.iterations += 1
+
+
+def main():
+    # Parse the CLI arguments.
+    args = parser.parse_args()
+
+    # Create a directory for saving trained models.
+    os.makedirs('models', exist_ok=True)
+
+    # Create the tensorflow dataset.
+    ds = DataLoader(args.image_dir, args.hr_size).dataset(args.batch_size)
+
+    # Initialize the GAN object.
+    gan = FastSRGAN(args)
+
+    # Define the directory for saving the pretraining loss tensorboard summary.
+    pretrain_summary_writer = tf.summary.create_file_writer('logs/pretrain')
+
+    # Run pre-training.
+    pretrain_generator(gan, ds, pretrain_summary_writer)
+
+    # Define the directory for saving the SRGAN training tensorboard summary.
+    train_summary_writer = tf.summary.create_file_writer('logs/train')
+
+    # Run training.
+    for _ in range(args.epochs):
+        train(gan, ds, args.save_iter, train_summary_writer)
+
+
+if __name__ == '__main__':
+    main()
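
Review note: main.py depends on only a narrow contract from the two modules it imports. As a reviewing aid, here is a minimal, runnable sketch of that assumed interface. The toy layer stacks, the Adam optimizers, the plain-MSE stand-in for content_loss, the 4x scale factor, and the synthetic dataset are all illustrative assumptions made for this sketch; the actual implementations live in model.py and dataloader.py.

import tensorflow as tf


class FastSRGAN:
    """Stand-in exposing exactly the attributes main.py touches (illustrative only)."""

    def __init__(self, args):
        lr_size = args.hr_size // 4  # assumed 4x upscaling factor
        # Toy networks; the real architectures live in model.py.
        self.generator = tf.keras.Sequential([
            tf.keras.layers.Input((lr_size, lr_size, 3)),
            tf.keras.layers.Conv2D(3, 3, padding='same'),
            tf.keras.layers.UpSampling2D(4),
        ])
        self.discriminator = tf.keras.Sequential([
            tf.keras.layers.Input((args.hr_size, args.hr_size, 3)),
            tf.keras.layers.Conv2D(1, 3, strides=16, padding='same', activation='sigmoid'),
        ])
        # One optimizer per network, built from the CLI learning rate.
        self.gen_optimizer = tf.keras.optimizers.Adam(args.lr)
        self.disc_optimizer = tf.keras.optimizers.Adam(args.lr)
        # Shape of one patch-discriminator output; assumes hr_size is divisible by 16.
        self.disc_patch = (args.hr_size // 16, args.hr_size // 16, 1)
        # Step counter that train() logs against and increments.
        self.iterations = 0
        # Plain MSE as a placeholder for the perceptual (e.g. VGG feature) loss.
        self.content_loss = tf.keras.losses.MeanSquaredError()


class DataLoader:
    """Stand-in whose dataset() yields (low_res, high_res) batches (illustrative only)."""

    def __init__(self, image_dir, hr_size):
        self.image_dir = image_dir  # ignored here; the real loader reads images from it
        self.hr_size = hr_size

    def dataset(self, batch_size):
        # Synthetic tensors standing in for decoded crops. The summary code in
        # train() implies low res images in [0, 1] and high res images in [-1, 1].
        lr = tf.random.uniform((2 * batch_size, self.hr_size // 4, self.hr_size // 4, 3))
        hr = tf.random.uniform((2 * batch_size, self.hr_size, self.hr_size, 3)) * 2.0 - 1.0
        # drop_remainder=True keeps x.shape[0] statically known, which the
        # tf.ones((x.shape[0],) + model.disc_patch) call in train_step requires.
        return tf.data.Dataset.from_tensor_slices((lr, hr)).batch(batch_size, drop_remainder=True)

Swapping these two classes in for the real modules lets python main.py --image_dir unused exercise the full pretrain/train loop end to end, which is a cheap way to check this patch's training plumbing without any image data.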