diff --git a/src/loss/adversarial.py b/src/loss/adversarial.py
index c4b7a4a90e234107130fe2f062a4c1b9f3429216..7517335e66c1b1a84065722c44ddb743abdf3055 100644
--- a/src/loss/adversarial.py
+++ b/src/loss/adversarial.py
@@ -1,4 +1,6 @@
 import utility
+from types import SimpleNamespace
+
 from model import common
 from loss import discriminator
 
@@ -12,37 +14,43 @@ class Adversarial(nn.Module):
         super(Adversarial, self).__init__()
         self.gan_type = gan_type
         self.gan_k = args.gan_k
-        self.discriminator = discriminator.Discriminator(args, gan_type)
-        if gan_type != 'WGAN_GP':
-            self.optimizer = utility.make_optimizer(args, self.discriminator)
+        self.dis = discriminator.Discriminator(args)
+        if gan_type == 'WGAN_GP':
+            # WGAN-GP discriminator optimizer settings (Adam, betas=(0, 0.9), lr=1e-5); see https://arxiv.org/pdf/1704.00028.pdf, p. 4
+            optim_dict = {
+                'optimizer': 'ADAM',
+                'betas': (0, 0.9),
+                'epsilon': 1e-8,
+                'lr': 1e-5,
+                'weight_decay': args.weight_decay,
+                'decay': args.decay,
+                'gamma': args.gamma
+            }
+            optim_args = SimpleNamespace(**optim_dict)
         else:
-            self.optimizer = optim.Adam(
-                self.discriminator.parameters(),
-                betas=(0, 0.9), eps=1e-8, lr=1e-5
-            )
-        self.scheduler = utility.make_scheduler(args, self.optimizer)
+            optim_args = args
 
-    def forward(self, fake, real):
-        fake_detach = fake.detach()
+        self.optimizer = utility.make_optimizer(optim_args, self.dis)
 
+    def forward(self, fake, real):
+        # update the discriminator gan_k times per generator step
         self.loss = 0
+        fake_detach = fake.detach()     # do not backpropagate through G
         for _ in range(self.gan_k):
             self.optimizer.zero_grad()
-            d_fake = self.discriminator(fake_detach)
-            d_real = self.discriminator(real)
+            # d: B x 1 tensor
+            d_fake = self.dis(fake_detach)
+            d_real = self.dis(real)
+            retain_graph = False
             if self.gan_type == 'GAN':
-                label_fake = torch.zeros_like(d_fake)
-                label_real = torch.ones_like(d_real)
-                loss_d \
-                    = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
-                    + F.binary_cross_entropy_with_logits(d_real, label_real)
+                loss_d = self.bce(d_real, d_fake)
             elif self.gan_type.find('WGAN') >= 0:
                 loss_d = (d_fake - d_real).mean()
                 if self.gan_type.find('GP') >= 0:
                     epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
                     hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                     hat.requires_grad = True
-                    d_hat = self.discriminator(hat)
+                    d_hat = self.dis(hat)
                     gradients = torch.autograd.grad(
                         outputs=d_hat.sum(), inputs=hat,
                         retain_graph=True, create_graph=True, only_inputs=True
@@ -51,34 +59,52 @@ class Adversarial(nn.Module):
                     gradient_norm = gradients.norm(2, dim=1)
                     gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                     loss_d += gradient_penalty
+            # relativistic GAN loss, as used in ESRGAN (Wang et al., 2018)
+            elif self.gan_type == 'RGAN':
+                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
+                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
+                loss_d = self.bce(better_real, better_fake)
+                retain_graph = True
 
             # Discriminator update
             self.loss += loss_d.item()
-            loss_d.backward()
+            loss_d.backward(retain_graph=retain_graph)
             self.optimizer.step()
 
             if self.gan_type == 'WGAN':
-                for p in self.discriminator.parameters():
+                for p in self.dis.parameters():
                     p.data.clamp_(-1, 1)
 
         self.loss /= self.gan_k
 
-        d_fake_for_g = self.discriminator(fake)
+        # updating generator...
+        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is
         if self.gan_type == 'GAN':
-            loss_g = F.binary_cross_entropy_with_logits(
-                d_fake_for_g, label_real
-            )
+            label_real = torch.ones_like(d_fake_bp)
+            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
         elif self.gan_type.find('WGAN') >= 0:
-            loss_g = -d_fake_for_g.mean()
+            loss_g = -d_fake_bp.mean()
+        elif self.gan_type == 'RGAN':
+            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
+            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
+            loss_g = self.bce(better_fake, better_real)
 
         # Generator loss
         return loss_g
     
     def state_dict(self, *args, **kwargs):
-        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
+        state_discriminator = self.dis.state_dict(*args, **kwargs)
         state_optimizer = self.optimizer.state_dict()
 
         return dict(**state_discriminator, **state_optimizer)
+
+    def bce(self, real, fake):
+        label_real = torch.ones_like(real)
+        label_fake = torch.zeros_like(fake)
+        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
+        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
+        bce_loss = bce_real + bce_fake
+        return bce_loss
                
 # Some references
 # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
diff --git a/src/loss/discriminator.py b/src/loss/discriminator.py
index 53fff1a11d269987ee619964016a44fb9b4de6d8..4581dfeb1a5539daddf03419cff527a20bebe7a0 100644
--- a/src/loss/discriminator.py
+++ b/src/loss/discriminator.py
@@ -3,7 +3,10 @@ from model import common
 import torch.nn as nn
 
 class Discriminator(nn.Module):
-    def __init__(self, args, gan_type='GAN'):
+    '''
+        Output is raw (unnormalized) logits; pair with BCE-with-logits-style losses.
+    '''
+    def __init__(self, args):
         super(Discriminator, self).__init__()
 
         in_channels = args.n_colors
diff --git a/src/loss/vgg.py b/src/loss/vgg.py
index a0167f5c8598abf1ad99ec81f32a6e13fefacbc1..335716d56b6bbd0876555f9e297480353288a0d9 100644
--- a/src/loss/vgg.py
+++ b/src/loss/vgg.py
@@ -10,9 +10,9 @@ class VGG(nn.Module):
         super(VGG, self).__init__()
         vgg_features = models.vgg19(pretrained=True).features
         modules = [m for m in vgg_features]
-        if conv_index == '22':
+        if conv_index.find('22') >= 0:
             self.vgg = nn.Sequential(*modules[:8])
-        elif conv_index == '54':
+        elif conv_index.find('54') >= 0:
             self.vgg = nn.Sequential(*modules[:35])
 
         vgg_mean = (0.485, 0.456, 0.406)
diff --git a/src/option.py b/src/option.py
index 672933096c14ddb4bf3b05c14382f0d4cdd206e5..8ec9634813b2a3b4799341e0b23e9bac856fc6e6 100644
--- a/src/option.py
+++ b/src/option.py
@@ -103,9 +103,7 @@ parser.add_argument('--gan_k', type=int, default=1,
 # Optimization specifications
 parser.add_argument('--lr', type=float, default=1e-4,
                     help='learning rate')
-parser.add_argument('--lr_decay', type=int, default=200,
-                    help='learning rate decay per N epochs')
-parser.add_argument('--decay_type', type=str, default='step',
+parser.add_argument('--decay', type=str, default='200',
                     help='learning rate decay type')
 parser.add_argument('--gamma', type=float, default=0.5,
                     help='learning rate decay factor for step decay')
@@ -114,7 +112,7 @@ parser.add_argument('--optimizer', default='ADAM',
                     help='optimizer to use (SGD | ADAM | RMSprop)')
 parser.add_argument('--momentum', type=float, default=0.9,
                     help='SGD momentum')
-parser.add_argument('--beta', type=tuple, default=(0.9, 0.999),
+parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                     help='ADAM beta')
 parser.add_argument('--epsilon', type=float, default=1e-8,
                     help='ADAM epsilon for numerical stability')
diff --git a/src/template.py b/src/template.py
index 755a7bce970ffc19c3f6c06f9d28fa84281e5542..2508b867f69a183fd6ed2c0c31effe7893d3326a 100644
--- a/src/template.py
+++ b/src/template.py
@@ -4,7 +4,7 @@ def set_template(args):
         args.data_train = 'DIV2K_jpeg'
         args.data_test = 'DIV2K_jpeg'
         args.epochs = 200
-        args.lr_decay = 100
+        args.decay = '100'
 
     if args.template.find('EDSR_paper') >= 0:
         args.model = 'EDSR'
@@ -26,7 +26,7 @@ def set_template(args):
 
         args.batch_size = 20
         args.epochs = 1000
-        args.lr_decay = 500
+        args.decay = '500'
         args.gamma = 0.1
         args.weight_decay = 1e-4
 
@@ -35,7 +35,7 @@ def set_template(args):
     if args.template.find('GAN') >= 0:
         args.epochs = 200
         args.lr = 5e-5
-        args.lr_decay = 150
+        args.decay = '150'
 
     if args.template.find('RCAN') >= 0:
         args.model = 'RCAN'
diff --git a/src/trainer.py b/src/trainer.py
index fb73373d2702e7940ad997a3866ff7acadd1064c..40b15a0f6dcfd1f35dd053eefe7a5a890bedd3af 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -19,21 +19,17 @@ class Trainer():
         self.model = my_model
         self.loss = my_loss
         self.optimizer = utility.make_optimizer(args, self.model)
-        self.scheduler = utility.make_scheduler(args, self.optimizer)
 
         if self.args.load != '':
-            self.optimizer.load_state_dict(
-                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
-            )
-            for _ in range(len(ckp.log)): self.scheduler.step()
+            self.optimizer.load(ckp.dir, epoch=len(ckp.log))
 
         self.error_last = 1e8
 
     def train(self):
-        self.scheduler.step()
+        self.optimizer.schedule()
         self.loss.step()
-        epoch = self.scheduler.last_epoch + 1
-        lr = self.scheduler.get_lr()[0]
+        epoch = self.optimizer.get_last_epoch() + 1
+        lr = self.optimizer.get_lr()
 
         self.ckp.write_log(
             '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
@@ -76,7 +72,7 @@ class Trainer():
     def test(self):
         torch.set_grad_enabled(False)
 
-        epoch = self.scheduler.last_epoch + 1
+        epoch = self.optimizer.get_last_epoch() + 1
         self.ckp.write_log('\nEvaluation:')
         self.ckp.add_log(
             torch.zeros(1, len(self.loader_test), len(self.scale))
@@ -118,7 +114,9 @@ class Trainer():
         self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
         self.ckp.write_log('Saving...')
 
-        if self.args.save_results: self.ckp.end_background()
+        if self.args.save_results:
+            self.ckp.end_background()
+
         if not self.args.test_only:
             self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
 
@@ -141,6 +139,6 @@ class Trainer():
             self.test()
             return True
         else:
-            epoch = self.scheduler.last_epoch + 1
+            epoch = self.optimizer.get_last_epoch() + 1
             return epoch >= self.args.epochs
 
diff --git a/src/utility.py b/src/utility.py
index 25aabd9f554fce1cbd6d830d8985ed9c6dad5276..7da69a701080a1a12285d6486d967befe6e6bce1 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -88,11 +88,8 @@ class checkpoint():
         trainer.loss.plot_loss(self.dir, epoch)
 
         self.plot_psnr(epoch)
+        trainer.optimizer.save(self.dir)
         torch.save(self.log, self.get_path('psnr_log.pt'))
-        torch.save(
-            trainer.optimizer.state_dict(),
-            self.get_path('optimizer.pt')
-        )
 
     def add_log(self, log):
         self.log = torch.cat([self.log, log])
@@ -183,35 +180,58 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
 
     return -10 * math.log10(mse)
 
-def make_optimizer(args, my_model):
-    trainable = filter(lambda x: x.requires_grad, my_model.parameters())
+def make_optimizer(args, target):
+    '''
+        Build an optimizer over target's trainable parameters with an attached MultiStepLR scheduler.
+    '''
+    # optimizer
+    trainable = filter(lambda x: x.requires_grad, target.parameters())
+    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
 
     if args.optimizer == 'SGD':
-        optimizer_function = optim.SGD
-        kwargs = {'momentum': args.momentum}
+        optimizer_class = optim.SGD
+        kwargs_optimizer['momentum'] = args.momentum
     elif args.optimizer == 'ADAM':
-        optimizer_function = optim.Adam
-        kwargs = {
-            'betas': args.beta,
-            'eps': args.epsilon
-        }
+        optimizer_class = optim.Adam
+        kwargs_optimizer['betas'] = args.betas
+        kwargs_optimizer['eps'] = args.epsilon
     elif args.optimizer == 'RMSprop':
-        optimizer_function = optim.RMSprop
-        kwargs = {'eps': args.epsilon}
+        optimizer_class = optim.RMSprop
+        kwargs_optimizer['eps'] = args.epsilon
 
-    kwargs['lr'] = args.lr
-    kwargs['weight_decay'] = args.weight_decay
+    # scheduler
+    milestones = list(map(lambda x: int(x), args.decay.split('-')))
+    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
+    scheduler_class = lrs.MultiStepLR
+
+    class CustomOptimizer(optimizer_class):
+        def __init__(self, *args, **kwargs):
+            super(CustomOptimizer, self).__init__(*args, **kwargs)
+
+        def _register_scheduler(self, scheduler_class, **kwargs):
+            self.scheduler = scheduler_class(self, **kwargs)
+
+        def save(self, save_dir):
+            torch.save(self.state_dict(), self.get_dir(save_dir))
+
+        def load(self, load_dir, epoch=1):
+            self.load_state_dict(torch.load(self.get_dir(load_dir)))
+            if epoch > 1:
+                for _ in range(epoch): self.scheduler.step()
+
+        def get_dir(self, dir_path):
+            return os.path.join(dir_path, 'optimizer.pt')
+
+        def schedule(self):
+            self.scheduler.step()
+
+        def get_lr(self):
+            return self.scheduler.get_lr()[0]
+
+        def get_last_epoch(self):
+            return self.scheduler.last_epoch
     
-    return optimizer_function(trainable, **kwargs)
-
-def make_scheduler(args, my_optimizer):
-    if args.decay_type == 'step':
-        scheduler_function = lrs.StepLR
-        kwargs = {'step_size': args.lr_decay, 'gamma': args.gamma}
-    elif args.decay_type.find('step') >= 0:
-        scheduler_function = lrs.MultiStepLR
-        milestones = list(map(lambda x: int(x), args.decay_type.split('-')[1:]))
-        kwarg = {'milestones': milestones, 'gamma': args.gamma}
-
-    return scheduler_function(my_optimizer, **kwargs)
+    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
+    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
+    return optimizer