diff --git a/src/loss/adversarial.py b/src/loss/adversarial.py index c4b7a4a90e234107130fe2f062a4c1b9f3429216..7517335e66c1b1a84065722c44ddb743abdf3055 100644 --- a/src/loss/adversarial.py +++ b/src/loss/adversarial.py @@ -1,4 +1,6 @@ import utility +from types import SimpleNamespace + from model import common from loss import discriminator @@ -12,37 +14,43 @@ class Adversarial(nn.Module): super(Adversarial, self).__init__() self.gan_type = gan_type self.gan_k = args.gan_k - self.discriminator = discriminator.Discriminator(args, gan_type) - if gan_type != 'WGAN_GP': - self.optimizer = utility.make_optimizer(args, self.discriminator) + self.dis = discriminator.Discriminator(args) + if gan_type == 'WGAN_GP': + # see https://arxiv.org/pdf/1704.00028.pdf pp.4 + optim_dict = { + 'optimizer': 'ADAM', + 'betas': (0, 0.9), + 'epsilon': 1e-8, + 'lr': 1e-5, + 'weight_decay': args.weight_decay, + 'decay': args.decay, + 'gamma': args.gamma + } + optim_args = SimpleNamespace(**optim_dict) else: - self.optimizer = optim.Adam( - self.discriminator.parameters(), - betas=(0, 0.9), eps=1e-8, lr=1e-5 - ) - self.scheduler = utility.make_scheduler(args, self.optimizer) + optim_args = args - def forward(self, fake, real): - fake_detach = fake.detach() + self.optimizer = utility.make_optimizer(optim_args, self.dis) + def forward(self, fake, real): + # updating discriminator... self.loss = 0 + fake_detach = fake.detach() # do not backpropagate through G for _ in range(self.gan_k): self.optimizer.zero_grad() - d_fake = self.discriminator(fake_detach) - d_real = self.discriminator(real) + # d: B x 1 tensor + d_fake = self.dis(fake_detach) + d_real = self.dis(real) + retain_graph = False if self.gan_type == 'GAN': - label_fake = torch.zeros_like(d_fake) - label_real = torch.ones_like(d_real) - loss_d \ - = F.binary_cross_entropy_with_logits(d_fake, label_fake) \ - + F.binary_cross_entropy_with_logits(d_real, label_real) + loss_d = self.bce(d_real, d_fake) elif self.gan_type.find('WGAN') >= 0: loss_d = (d_fake - d_real).mean() if self.gan_type.find('GP') >= 0: epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) hat.requires_grad = True - d_hat = self.discriminator(hat) + d_hat = self.dis(hat) gradients = torch.autograd.grad( outputs=d_hat.sum(), inputs=hat, retain_graph=True, create_graph=True, only_inputs=True @@ -51,34 +59,52 @@ class Adversarial(nn.Module): gradient_norm = gradients.norm(2, dim=1) gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() loss_d += gradient_penalty + # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks + elif self.gan_type == 'RGAN': + better_real = d_real - d_fake.mean(dim=0, keepdim=True) + better_fake = d_fake - d_real.mean(dim=0, keepdim=True) + loss_d = self.bce(better_real, better_fake) + retain_graph = True # Discriminator update self.loss += loss_d.item() - loss_d.backward() + loss_d.backward(retain_graph=retain_graph) self.optimizer.step() if self.gan_type == 'WGAN': - for p in self.discriminator.parameters(): + for p in self.dis.parameters(): p.data.clamp_(-1, 1) self.loss /= self.gan_k - d_fake_for_g = self.discriminator(fake) + # updating generator... + d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is if self.gan_type == 'GAN': - loss_g = F.binary_cross_entropy_with_logits( - d_fake_for_g, label_real - ) + label_real = torch.ones_like(d_fake_bp) + loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) elif self.gan_type.find('WGAN') >= 0: - loss_g = -d_fake_for_g.mean() + loss_g = -d_fake_bp.mean() + elif self.gan_type == 'RGAN': + better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True) + better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True) + loss_g = self.bce(better_fake, better_real) # Generator loss return loss_g def state_dict(self, *args, **kwargs): - state_discriminator = self.discriminator.state_dict(*args, **kwargs) + state_discriminator = self.dis.state_dict(*args, **kwargs) state_optimizer = self.optimizer.state_dict() return dict(**state_discriminator, **state_optimizer) + + def bce(self, real, fake): + label_real = torch.ones_like(real) + label_fake = torch.zeros_like(fake) + bce_real = F.binary_cross_entropy_with_logits(real, label_real) + bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) + bce_loss = bce_real + bce_fake + return bce_loss # Some references # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py diff --git a/src/loss/discriminator.py b/src/loss/discriminator.py index 53fff1a11d269987ee619964016a44fb9b4de6d8..4581dfeb1a5539daddf03419cff527a20bebe7a0 100644 --- a/src/loss/discriminator.py +++ b/src/loss/discriminator.py @@ -3,7 +3,10 @@ from model import common import torch.nn as nn class Discriminator(nn.Module): - def __init__(self, args, gan_type='GAN'): + ''' + output is not normalized + ''' + def __init__(self, args): super(Discriminator, self).__init__() in_channels = args.n_colors diff --git a/src/loss/vgg.py b/src/loss/vgg.py index a0167f5c8598abf1ad99ec81f32a6e13fefacbc1..335716d56b6bbd0876555f9e297480353288a0d9 100644 --- a/src/loss/vgg.py +++ b/src/loss/vgg.py @@ -10,9 +10,9 @@ class VGG(nn.Module): super(VGG, self).__init__() vgg_features = models.vgg19(pretrained=True).features modules = [m for m in vgg_features] - if conv_index == '22': + if conv_index.find('22') >= 0: self.vgg = nn.Sequential(*modules[:8]) - elif conv_index == '54': + elif conv_index.find('54') >= 0: self.vgg = nn.Sequential(*modules[:35]) vgg_mean = (0.485, 0.456, 0.406) diff --git a/src/option.py b/src/option.py index 672933096c14ddb4bf3b05c14382f0d4cdd206e5..8ec9634813b2a3b4799341e0b23e9bac856fc6e6 100644 --- a/src/option.py +++ b/src/option.py @@ -103,9 +103,7 @@ parser.add_argument('--gan_k', type=int, default=1, # Optimization specifications parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') -parser.add_argument('--lr_decay', type=int, default=200, - help='learning rate decay per N epochs') -parser.add_argument('--decay_type', type=str, default='step', +parser.add_argument('--decay', type=str, default='200', help='learning rate decay type') parser.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay') @@ -114,7 +112,7 @@ parser.add_argument('--optimizer', default='ADAM', help='optimizer to use (SGD | ADAM | RMSprop)') parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') -parser.add_argument('--beta', type=tuple, default=(0.9, 0.999), +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='ADAM beta') parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability') diff --git a/src/template.py b/src/template.py index 755a7bce970ffc19c3f6c06f9d28fa84281e5542..2508b867f69a183fd6ed2c0c31effe7893d3326a 100644 --- a/src/template.py +++ b/src/template.py @@ -4,7 +4,7 @@ def set_template(args): args.data_train = 'DIV2K_jpeg' args.data_test = 'DIV2K_jpeg' args.epochs = 200 - args.lr_decay = 100 + args.decay = '100' if args.template.find('EDSR_paper') >= 0: args.model = 'EDSR' @@ -26,7 +26,7 @@ def set_template(args): args.batch_size = 20 args.epochs = 1000 - args.lr_decay = 500 + args.decay = '500' args.gamma = 0.1 args.weight_decay = 1e-4 @@ -35,7 +35,7 @@ def set_template(args): if args.template.find('GAN') >= 0: args.epochs = 200 args.lr = 5e-5 - args.lr_decay = 150 + args.decay = '150' if args.template.find('RCAN') >= 0: args.model = 'RCAN' diff --git a/src/trainer.py b/src/trainer.py index fb73373d2702e7940ad997a3866ff7acadd1064c..40b15a0f6dcfd1f35dd053eefe7a5a890bedd3af 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -19,21 +19,17 @@ class Trainer(): self.model = my_model self.loss = my_loss self.optimizer = utility.make_optimizer(args, self.model) - self.scheduler = utility.make_scheduler(args, self.optimizer) if self.args.load != '': - self.optimizer.load_state_dict( - torch.load(os.path.join(ckp.dir, 'optimizer.pt')) - ) - for _ in range(len(ckp.log)): self.scheduler.step() + self.optimizer.load(ckp.dir, epoch=len(ckp.log)) self.error_last = 1e8 def train(self): - self.scheduler.step() + self.optimizer.schedule() self.loss.step() - epoch = self.scheduler.last_epoch + 1 - lr = self.scheduler.get_lr()[0] + epoch = self.optimizer.get_last_epoch() + 1 + lr = self.optimizer.get_lr() self.ckp.write_log( '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) @@ -76,7 +72,7 @@ class Trainer(): def test(self): torch.set_grad_enabled(False) - epoch = self.scheduler.last_epoch + 1 + epoch = self.optimizer.get_last_epoch() + 1 self.ckp.write_log('\nEvaluation:') self.ckp.add_log( torch.zeros(1, len(self.loader_test), len(self.scale)) @@ -118,7 +114,9 @@ class Trainer(): self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) self.ckp.write_log('Saving...') - if self.args.save_results: self.ckp.end_background() + if self.args.save_results: + self.ckp.end_background() + if not self.args.test_only: self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) @@ -141,6 +139,6 @@ class Trainer(): self.test() return True else: - epoch = self.scheduler.last_epoch + 1 + epoch = self.optimizer.get_last_epoch() + 1 return epoch >= self.args.epochs diff --git a/src/utility.py b/src/utility.py index 25aabd9f554fce1cbd6d830d8985ed9c6dad5276..7da69a701080a1a12285d6486d967befe6e6bce1 100644 --- a/src/utility.py +++ b/src/utility.py @@ -88,11 +88,8 @@ class checkpoint(): trainer.loss.plot_loss(self.dir, epoch) self.plot_psnr(epoch) + trainer.optimizer.save(self.dir) torch.save(self.log, self.get_path('psnr_log.pt')) - torch.save( - trainer.optimizer.state_dict(), - self.get_path('optimizer.pt') - ) def add_log(self, log): self.log = torch.cat([self.log, log]) @@ -183,35 +180,58 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None): return -10 * math.log10(mse) -def make_optimizer(args, my_model): - trainable = filter(lambda x: x.requires_grad, my_model.parameters()) +def make_optimizer(args, target): + ''' + make optimizer and scheduler together + ''' + # optimizer + trainable = filter(lambda x: x.requires_grad, target.parameters()) + kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} if args.optimizer == 'SGD': - optimizer_function = optim.SGD - kwargs = {'momentum': args.momentum} + optimizer_class = optim.SGD + kwargs_optimizer['momentum'] = args.momentum elif args.optimizer == 'ADAM': - optimizer_function = optim.Adam - kwargs = { - 'betas': args.beta, - 'eps': args.epsilon - } + optimizer_class = optim.Adam + kwargs_optimizer['betas'] = args.betas + kwargs_optimizer['eps'] = args.epsilon elif args.optimizer == 'RMSprop': - optimizer_function = optim.RMSprop - kwargs = {'eps': args.epsilon} + optimizer_class = optim.RMSprop + kwargs_optimizer['eps'] = args.epsilon - kwargs['lr'] = args.lr - kwargs['weight_decay'] = args.weight_decay + # scheduler + milestones = list(map(lambda x: int(x), args.decay.split('-'))) + kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} + scheduler_class = lrs.MultiStepLR + + class CustomOptimizer(optimizer_class): + def __init__(self, *args, **kwargs): + super(CustomOptimizer, self).__init__(*args, **kwargs) + + def _register_scheduler(self, scheduler_class, **kwargs): + self.scheduler = scheduler_class(self, **kwargs) + + def save(self, save_dir): + torch.save(self.state_dict(), self.get_dir(save_dir)) + + def load(self, load_dir, epoch=1): + self.load_state_dict(torch.load(self.get_dir(load_dir))) + if epoch > 1: + for _ in range(epoch): self.scheduler.step() + + def get_dir(self, dir_path): + return os.path.join(dir_path, 'optimizer.pt') + + def schedule(self): + self.scheduler.step() + + def get_lr(self): + return self.scheduler.get_lr()[0] + + def get_last_epoch(self): + return self.scheduler.last_epoch - return optimizer_function(trainable, **kwargs) - -def make_scheduler(args, my_optimizer): - if args.decay_type == 'step': - scheduler_function = lrs.StepLR - kwargs = {'step_size': args.lr_decay, 'gamma': args.gamma} - elif args.decay_type.find('step') >= 0: - scheduler_function = lrs.MultiStepLR - milestones = list(map(lambda x: int(x), args.decay_type.split('-')[1:])) - kwarg = {'milestones': milestones, 'gamma': args.gamma} - - return scheduler_function(my_optimizer, **kwargs) + optimizer = CustomOptimizer(trainable, **kwargs_optimizer) + optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) + return optimizer