From 1fe06e82c1a86af4ad649ff38e1a765f9a40c6e4 Mon Sep 17 00:00:00 2001
From: park beom su <poj4639@ajou.ac.kr>
Date: Thu, 15 Dec 2022 21:36:43 +0900
Subject: [PATCH] Upload New File

---
 model.py | 239 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 239 insertions(+)
 create mode 100644 model.py

diff --git a/model.py b/model.py
new file mode 100644
index 0000000..0f96d8b
--- /dev/null
+++ b/model.py
@@ -0,0 +1,239 @@
+from tensorflow import keras
+import tensorflow as tf
+
+
+class FastSRGAN(object):
+    """SRGAN for fast super resolution."""
+
+    def __init__(self, args):
+        """
+        Initializes the Mobile SRGAN class.
+        Args:
+            args: CLI arguments that dictate how to build the model.
+        Returns:
+            None
+        """
+        self.hr_height = args.hr_size
+        self.hr_width = args.hr_size
+        self.lr_height = self.hr_height // 4  # Low resolution height
+        self.lr_width = self.hr_width // 4  # Low resolution width
+        self.lr_shape = (self.lr_height, self.lr_width, 3)
+        self.hr_shape = (self.hr_height, self.hr_width, 3)
+        self.iterations = 0
+
+        # Number of inverted residual blocks in the MobileNet-style generator
+        self.n_residual_blocks = 6
+
+        # Define a learning rate decay schedule.
+        self.gen_schedule = keras.optimizers.schedules.ExponentialDecay(
+            args.lr,
+            decay_steps=100000,
+            decay_rate=0.1,
+            staircase=True
+        )
+
+        self.disc_schedule = keras.optimizers.schedules.ExponentialDecay(
+            args.lr * 5,  # TTUR: two time-scale update rule
+            decay_steps=100000,
+            decay_rate=0.1,
+            staircase=True
+        )
+
+        self.gen_optimizer = keras.optimizers.Adam(learning_rate=self.gen_schedule)
+        self.disc_optimizer = keras.optimizers.Adam(learning_rate=self.disc_schedule)
+
+        # We use a pre-trained VGG19 model to extract image features from the high resolution
+        # and the generated high resolution images and minimize the MSE between them.
+        self.vgg = self.build_vgg()
+        self.vgg.trainable = False
+
+        # Calculate output shape of D (PatchGAN)
+        patch = int(self.hr_height / 2 ** 4)
+        self.disc_patch = (patch, patch, 1)
+
+        # Number of filters in the first layer of G and D
+        self.gf = 32  # Real-time image enhancement GAN, Galteri et al.
+        self.df = 32
+
+        # Build the discriminator.
+        self.discriminator = self.build_discriminator()
+
+        # Build the generator for pretraining.
+        self.generator = self.build_generator()
+
+    @tf.function
+    def content_loss(self, hr, sr):
+        sr = keras.applications.vgg19.preprocess_input(((sr + 1.0) * 255) / 2.0)
+        hr = keras.applications.vgg19.preprocess_input(((hr + 1.0) * 255) / 2.0)
+        sr_features = self.vgg(sr) / 12.75
+        hr_features = self.vgg(hr) / 12.75
+        return tf.keras.losses.MeanSquaredError()(hr_features, sr_features)
+
+    def build_vgg(self):
+        """
+        Builds a pre-trained VGG19 model that outputs image features extracted at the
+        last convolution of the fifth block of the model.
+        """
+        # Get the vgg network. Extract features from Block 5, last convolution.
+        vgg = keras.applications.VGG19(weights="imagenet", input_shape=self.hr_shape, include_top=False)
+        vgg.trainable = False
+        for layer in vgg.layers:
+            layer.trainable = False
+
+        # Create the feature extraction model.
+        model = keras.models.Model(inputs=vgg.input, outputs=vgg.get_layer("block5_conv4").output)
+
+        return model
+
+    def build_generator(self):
+        """Build the generator that will do the Super Resolution task.
+        Based on the MobileNet design; idea from Galteri et al."""
+
+        def _make_divisible(v, divisor, min_value=None):
+            if min_value is None:
+                min_value = divisor
+            new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+            # Make sure that rounding down does not go down by more than 10%.
+            if new_v < 0.9 * v:
+                new_v += divisor
+            return new_v
+
+        def residual_block(inputs, filters, block_id, expansion=6, stride=1, alpha=1.0):
+            """Inverted residual block that uses depthwise convolutions for parameter efficiency.
+            Args:
+                inputs: The input feature map.
+                filters: Number of filters in each convolution in the block.
+                block_id: An integer specifier for the id of the block in the graph.
+                expansion: Channel expansion factor.
+                stride: The stride of the convolution.
+                alpha: Depth expansion factor.
+            Returns:
+                x: The output of the inverted residual block.
+            """
+            channel_axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
+
+            in_channels = keras.backend.int_shape(inputs)[channel_axis]
+            pointwise_conv_filters = int(filters * alpha)
+            pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
+            x = inputs
+            prefix = 'block_{}_'.format(block_id)
+
+            if block_id:
+                # Expand
+                x = keras.layers.Conv2D(expansion * in_channels,
+                                        kernel_size=1,
+                                        padding='same',
+                                        use_bias=True,
+                                        activation=None,
+                                        name=prefix + 'expand')(x)
+                x = keras.layers.BatchNormalization(axis=channel_axis,
+                                                    epsilon=1e-3,
+                                                    momentum=0.999,
+                                                    name=prefix + 'expand_BN')(x)
+                x = keras.layers.Activation('relu', name=prefix + 'expand_relu')(x)
+            else:
+                prefix = 'expanded_conv_'
+
+            # Depthwise
+            x = keras.layers.DepthwiseConv2D(kernel_size=3,
+                                             strides=stride,
+                                             activation=None,
+                                             use_bias=True,
+                                             padding='same' if stride == 1 else 'valid',
+                                             name=prefix + 'depthwise')(x)
+            x = keras.layers.BatchNormalization(axis=channel_axis,
+                                                epsilon=1e-3,
+                                                momentum=0.999,
+                                                name=prefix + 'depthwise_BN')(x)
+
+            x = keras.layers.Activation('relu', name=prefix + 'depthwise_relu')(x)
+
+            # Project
+            x = keras.layers.Conv2D(pointwise_filters,
+                                    kernel_size=1,
+                                    padding='same',
+                                    use_bias=True,
+                                    activation=None,
+                                    name=prefix + 'project')(x)
+            x = keras.layers.BatchNormalization(axis=channel_axis,
+                                                epsilon=1e-3,
+                                                momentum=0.999,
+                                                name=prefix + 'project_BN')(x)
+
+            if in_channels == pointwise_filters and stride == 1:
+                return keras.layers.Add(name=prefix + 'add')([inputs, x])
+            return x
+
+        def deconv2d(layer_input):
+            """Upsampling layer to increase height and width of the input.
+            Uses bilinear upsampling followed by a convolution.
+            Args:
+                layer_input: The input tensor to upsample.
+            Returns:
+                u: Upsampled input by a factor of 2.
+            """
+            u = keras.layers.UpSampling2D(size=2, interpolation='bilinear')(layer_input)
+            u = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(u)
+            u = keras.layers.PReLU(shared_axes=[1, 2])(u)
+            return u
+
+        # Low resolution image input
+        img_lr = keras.Input(shape=self.lr_shape)
+
+        # Pre-residual block
+        c1 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(img_lr)
+        c1 = keras.layers.BatchNormalization()(c1)
+        c1 = keras.layers.PReLU(shared_axes=[1, 2])(c1)
+
+        # Propagate through residual blocks
+        r = residual_block(c1, self.gf, 0)
+        for idx in range(1, self.n_residual_blocks):
+            r = residual_block(r, self.gf, idx)
+
+        # Post-residual block
+        c2 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(r)
+        c2 = keras.layers.BatchNormalization()(c2)
+        c2 = keras.layers.Add()([c2, c1])
+
+        # Upsampling
+        u1 = deconv2d(c2)
+        u2 = deconv2d(u1)
+
+        # Generate high resolution output
+        gen_hr = keras.layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh')(u2)
+
+        return keras.models.Model(img_lr, gen_hr)
+
+    def build_discriminator(self):
+        """Builds a discriminator network based on the SRGAN design."""
+
+        def d_block(layer_input, filters, strides=1, bn=True):
+            """Discriminator layer block.
+            Args:
+                layer_input: Input feature map for the convolutional block.
+                filters: Number of filters in the convolution.
+                strides: The stride of the convolution.
+                bn: Whether to use batch norm or not.
+            """
+            d = keras.layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
+            if bn:
+                d = keras.layers.BatchNormalization(momentum=0.8)(d)
+            d = keras.layers.LeakyReLU(alpha=0.2)(d)
+
+            return d
+
+        # Input image
+        d0 = keras.layers.Input(shape=self.hr_shape)
+
+        d1 = d_block(d0, self.df, bn=False)
+        d2 = d_block(d1, self.df, strides=2)
+        d3 = d_block(d2, self.df)
+        d4 = d_block(d3, self.df, strides=2)
+        d5 = d_block(d4, self.df * 2)
+        d6 = d_block(d5, self.df * 2, strides=2)
+        d7 = d_block(d6, self.df * 2)
+        d8 = d_block(d7, self.df * 2, strides=2)
+
+        validity = keras.layers.Conv2D(1, kernel_size=1, strides=1, activation='sigmoid', padding='same')(d8)
+
+        return keras.models.Model(d0, validity)
--
GitLab