Skip to content
Snippets Groups Projects
Commit 72a506ae authored by Sanghyun Son's avatar Sanghyun Son
Browse files

test relativistic discriminator

    from ECCV PIRM2018:
        ESRGAN: Enhanced Super-Resolution Generative Adversarial
        Networks
parent 70294bb6
Branches
No related tags found
1 merge request: !1 "Jan 09, 2018 updates"
import utility import utility
from types import SimpleNamespace
from model import common from model import common
from loss import discriminator from loss import discriminator
...@@ -12,37 +14,43 @@ class Adversarial(nn.Module): ...@@ -12,37 +14,43 @@ class Adversarial(nn.Module):
super(Adversarial, self).__init__() super(Adversarial, self).__init__()
self.gan_type = gan_type self.gan_type = gan_type
self.gan_k = args.gan_k self.gan_k = args.gan_k
self.discriminator = discriminator.Discriminator(args, gan_type) self.dis = discriminator.Discriminator(args)
if gan_type != 'WGAN_GP': if gan_type == 'WGAN_GP':
self.optimizer = utility.make_optimizer(args, self.discriminator) # see https://arxiv.org/pdf/1704.00028.pdf pp.4
optim_dict = {
'optimizer': 'ADAM',
'betas': (0, 0.9),
'epsilon': 1e-8,
'lr': 1e-5,
'weight_decay': args.weight_decay,
'decay': args.decay,
'gamma': args.gamma
}
optim_args = SimpleNamespace(**optim_dict)
else: else:
self.optimizer = optim.Adam( optim_args = args
self.discriminator.parameters(),
betas=(0, 0.9), eps=1e-8, lr=1e-5
)
self.scheduler = utility.make_scheduler(args, self.optimizer)
def forward(self, fake, real): self.optimizer = utility.make_optimizer(optim_args, self.dis)
fake_detach = fake.detach()
def forward(self, fake, real):
# updating discriminator...
self.loss = 0 self.loss = 0
fake_detach = fake.detach() # do not backpropagate through G
for _ in range(self.gan_k): for _ in range(self.gan_k):
self.optimizer.zero_grad() self.optimizer.zero_grad()
d_fake = self.discriminator(fake_detach) # d: B x 1 tensor
d_real = self.discriminator(real) d_fake = self.dis(fake_detach)
d_real = self.dis(real)
retain_graph = False
if self.gan_type == 'GAN': if self.gan_type == 'GAN':
label_fake = torch.zeros_like(d_fake) loss_d = self.bce(d_real, d_fake)
label_real = torch.ones_like(d_real)
loss_d \
= F.binary_cross_entropy_with_logits(d_fake, label_fake) \
+ F.binary_cross_entropy_with_logits(d_real, label_real)
elif self.gan_type.find('WGAN') >= 0: elif self.gan_type.find('WGAN') >= 0:
loss_d = (d_fake - d_real).mean() loss_d = (d_fake - d_real).mean()
if self.gan_type.find('GP') >= 0: if self.gan_type.find('GP') >= 0:
epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
hat.requires_grad = True hat.requires_grad = True
d_hat = self.discriminator(hat) d_hat = self.dis(hat)
gradients = torch.autograd.grad( gradients = torch.autograd.grad(
outputs=d_hat.sum(), inputs=hat, outputs=d_hat.sum(), inputs=hat,
retain_graph=True, create_graph=True, only_inputs=True retain_graph=True, create_graph=True, only_inputs=True
...@@ -51,35 +59,53 @@ class Adversarial(nn.Module): ...@@ -51,35 +59,53 @@ class Adversarial(nn.Module):
gradient_norm = gradients.norm(2, dim=1) gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
loss_d += gradient_penalty loss_d += gradient_penalty
# from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
elif self.gan_type == 'RGAN':
better_real = d_real - d_fake.mean(dim=0, keepdim=True)
better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
loss_d = self.bce(better_real, better_fake)
retain_graph = True
# Discriminator update # Discriminator update
self.loss += loss_d.item() self.loss += loss_d.item()
loss_d.backward() loss_d.backward(retain_graph=retain_graph)
self.optimizer.step() self.optimizer.step()
if self.gan_type == 'WGAN': if self.gan_type == 'WGAN':
for p in self.discriminator.parameters(): for p in self.dis.parameters():
p.data.clamp_(-1, 1) p.data.clamp_(-1, 1)
self.loss /= self.gan_k self.loss /= self.gan_k
d_fake_for_g = self.discriminator(fake) # updating generator...
d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
if self.gan_type == 'GAN': if self.gan_type == 'GAN':
loss_g = F.binary_cross_entropy_with_logits( label_real = torch.ones_like(d_fake_bp)
d_fake_for_g, label_real loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
)
elif self.gan_type.find('WGAN') >= 0: elif self.gan_type.find('WGAN') >= 0:
loss_g = -d_fake_for_g.mean() loss_g = -d_fake_bp.mean()
elif self.gan_type == 'RGAN':
better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
loss_g = self.bce(better_fake, better_real)
# Generator loss # Generator loss
return loss_g return loss_g
def state_dict(self, *args, **kwargs):
    '''
    Return the discriminator's state merged with the optimizer's state.

    Positional and keyword arguments are forwarded unchanged to
    ``self.dis.state_dict``. The optimizer entries are added to the same
    flat dictionary as the discriminator entries.

    Raises:
        TypeError: if the two state dictionaries share a key, because
            ``dict(**a, **b)`` rejects duplicate keyword names.
    '''
    state_discriminator = self.dis.state_dict(*args, **kwargs)
    state_optimizer = self.optimizer.state_dict()
    return dict(**state_discriminator, **state_optimizer)
def bce(self, real, fake):
    '''
    Two-sided binary cross-entropy on raw (un-normalized) logits.

    ``real`` is scored against an all-ones target and ``fake`` against an
    all-zeros target; the two mean-reduced losses are summed.

    Args:
        real: logits that should be classified as real.
        fake: logits that should be classified as fake.

    Returns:
        Scalar tensor holding the summed loss.
    '''
    label_real = torch.ones_like(real)
    label_fake = torch.zeros_like(fake)
    bce_real = F.binary_cross_entropy_with_logits(real, label_real)
    bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
    return bce_real + bce_fake
# Some references # Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR # OR
......
...@@ -3,7 +3,10 @@ from model import common ...@@ -3,7 +3,10 @@ from model import common
import torch.nn as nn import torch.nn as nn
class Discriminator(nn.Module): class Discriminator(nn.Module):
def __init__(self, args, gan_type='GAN'): '''
output is not normalized
'''
def __init__(self, args):
super(Discriminator, self).__init__() super(Discriminator, self).__init__()
in_channels = args.n_colors in_channels = args.n_colors
......
...@@ -10,9 +10,9 @@ class VGG(nn.Module): ...@@ -10,9 +10,9 @@ class VGG(nn.Module):
super(VGG, self).__init__() super(VGG, self).__init__()
vgg_features = models.vgg19(pretrained=True).features vgg_features = models.vgg19(pretrained=True).features
modules = [m for m in vgg_features] modules = [m for m in vgg_features]
if conv_index == '22': if conv_index.find('22') >= 0:
self.vgg = nn.Sequential(*modules[:8]) self.vgg = nn.Sequential(*modules[:8])
elif conv_index == '54': elif conv_index.find('54') >= 0:
self.vgg = nn.Sequential(*modules[:35]) self.vgg = nn.Sequential(*modules[:35])
vgg_mean = (0.485, 0.456, 0.406) vgg_mean = (0.485, 0.456, 0.406)
......
...@@ -103,9 +103,7 @@ parser.add_argument('--gan_k', type=int, default=1, ...@@ -103,9 +103,7 @@ parser.add_argument('--gan_k', type=int, default=1,
# Optimization specifications # Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4, parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate') help='learning rate')
parser.add_argument('--lr_decay', type=int, default=200, parser.add_argument('--decay', type=str, default='200',
help='learning rate decay per N epochs')
parser.add_argument('--decay_type', type=str, default='step',
help='learning rate decay type') help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5, parser.add_argument('--gamma', type=float, default=0.5,
help='learning rate decay factor for step decay') help='learning rate decay factor for step decay')
...@@ -114,7 +112,7 @@ parser.add_argument('--optimizer', default='ADAM', ...@@ -114,7 +112,7 @@ parser.add_argument('--optimizer', default='ADAM',
help='optimizer to use (SGD | ADAM | RMSprop)') help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9, parser.add_argument('--momentum', type=float, default=0.9,
help='SGD momentum') help='SGD momentum')
# type=tuple would split a command-line string into single characters
# (tuple('0.9') == ('0', '.', '9')), so parse two floats instead:
# --betas 0.9 0.999
parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999),
                    help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8, parser.add_argument('--epsilon', type=float, default=1e-8,
help='ADAM epsilon for numerical stability') help='ADAM epsilon for numerical stability')
......
...@@ -4,7 +4,7 @@ def set_template(args): ...@@ -4,7 +4,7 @@ def set_template(args):
args.data_train = 'DIV2K_jpeg' args.data_train = 'DIV2K_jpeg'
args.data_test = 'DIV2K_jpeg' args.data_test = 'DIV2K_jpeg'
args.epochs = 200 args.epochs = 200
args.lr_decay = 100 args.decay = '100'
if args.template.find('EDSR_paper') >= 0: if args.template.find('EDSR_paper') >= 0:
args.model = 'EDSR' args.model = 'EDSR'
...@@ -26,7 +26,7 @@ def set_template(args): ...@@ -26,7 +26,7 @@ def set_template(args):
args.batch_size = 20 args.batch_size = 20
args.epochs = 1000 args.epochs = 1000
args.lr_decay = 500 args.decay = '500'
args.gamma = 0.1 args.gamma = 0.1
args.weight_decay = 1e-4 args.weight_decay = 1e-4
...@@ -35,7 +35,7 @@ def set_template(args): ...@@ -35,7 +35,7 @@ def set_template(args):
if args.template.find('GAN') >= 0: if args.template.find('GAN') >= 0:
args.epochs = 200 args.epochs = 200
args.lr = 5e-5 args.lr = 5e-5
args.lr_decay = 150 args.decay = '150'
if args.template.find('RCAN') >= 0: if args.template.find('RCAN') >= 0:
args.model = 'RCAN' args.model = 'RCAN'
......
...@@ -19,21 +19,17 @@ class Trainer(): ...@@ -19,21 +19,17 @@ class Trainer():
self.model = my_model self.model = my_model
self.loss = my_loss self.loss = my_loss
self.optimizer = utility.make_optimizer(args, self.model) self.optimizer = utility.make_optimizer(args, self.model)
self.scheduler = utility.make_scheduler(args, self.optimizer)
if self.args.load != '': if self.args.load != '':
self.optimizer.load_state_dict( self.optimizer.load(ckp.dir, epoch=len(ckp.log))
torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
)
for _ in range(len(ckp.log)): self.scheduler.step()
self.error_last = 1e8 self.error_last = 1e8
def train(self): def train(self):
self.scheduler.step() self.optimizer.schedule()
self.loss.step() self.loss.step()
epoch = self.scheduler.last_epoch + 1 epoch = self.optimizer.get_last_epoch() + 1
lr = self.scheduler.get_lr()[0] lr = self.optimizer.get_lr()
self.ckp.write_log( self.ckp.write_log(
'[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
...@@ -76,7 +72,7 @@ class Trainer(): ...@@ -76,7 +72,7 @@ class Trainer():
def test(self): def test(self):
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
epoch = self.scheduler.last_epoch + 1 epoch = self.optimizer.get_last_epoch() + 1
self.ckp.write_log('\nEvaluation:') self.ckp.write_log('\nEvaluation:')
self.ckp.add_log( self.ckp.add_log(
torch.zeros(1, len(self.loader_test), len(self.scale)) torch.zeros(1, len(self.loader_test), len(self.scale))
...@@ -118,7 +114,9 @@ class Trainer(): ...@@ -118,7 +114,9 @@ class Trainer():
self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
self.ckp.write_log('Saving...') self.ckp.write_log('Saving...')
if self.args.save_results: self.ckp.end_background() if self.args.save_results:
self.ckp.end_background()
if not self.args.test_only: if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
...@@ -141,6 +139,6 @@ class Trainer(): ...@@ -141,6 +139,6 @@ class Trainer():
self.test() self.test()
return True return True
else: else:
epoch = self.scheduler.last_epoch + 1 epoch = self.optimizer.get_last_epoch() + 1
return epoch >= self.args.epochs return epoch >= self.args.epochs
...@@ -88,11 +88,8 @@ class checkpoint(): ...@@ -88,11 +88,8 @@ class checkpoint():
trainer.loss.plot_loss(self.dir, epoch) trainer.loss.plot_loss(self.dir, epoch)
self.plot_psnr(epoch) self.plot_psnr(epoch)
trainer.optimizer.save(self.dir)
torch.save(self.log, self.get_path('psnr_log.pt')) torch.save(self.log, self.get_path('psnr_log.pt'))
torch.save(
trainer.optimizer.state_dict(),
self.get_path('optimizer.pt')
)
def add_log(self, log): def add_log(self, log):
self.log = torch.cat([self.log, log]) self.log = torch.cat([self.log, log])
...@@ -183,35 +180,58 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None): ...@@ -183,35 +180,58 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
return -10 * math.log10(mse) return -10 * math.log10(mse)
def make_optimizer(args, target):
    '''
    Make an optimizer and its learning-rate scheduler together.

    Args:
        args: namespace providing ``optimizer`` ('SGD' | 'ADAM' |
            'RMSprop'), ``lr``, ``weight_decay``, ``gamma`` and ``decay``
            (epoch milestones joined by '-', e.g. '200' or '100-150'),
            plus ``momentum`` for SGD, ``betas`` and ``epsilon`` for ADAM,
            and ``epsilon`` for RMSprop.
        target: object whose ``parameters()`` are optimized; only the
            parameters with ``requires_grad`` set are handed over.

    Returns:
        An instance of a subclass of the selected optimizer class that
        owns a ``MultiStepLR`` scheduler and exposes ``save``, ``load``,
        ``schedule``, ``get_lr`` and ``get_last_epoch``.

    Raises:
        ValueError: if ``args.optimizer`` is not a supported name.
    '''
    # optimizer
    trainable = filter(lambda x: x.requires_grad, target.parameters())
    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}

    if args.optimizer == 'SGD':
        optimizer_class = optim.SGD
        kwargs_optimizer['momentum'] = args.momentum
    elif args.optimizer == 'ADAM':
        optimizer_class = optim.Adam
        kwargs_optimizer['betas'] = args.betas
        kwargs_optimizer['eps'] = args.epsilon
    elif args.optimizer == 'RMSprop':
        optimizer_class = optim.RMSprop
        kwargs_optimizer['eps'] = args.epsilon
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(
            'unsupported optimizer: {!r} (expected SGD, ADAM or RMSprop)'
            .format(args.optimizer)
        )

    # scheduler
    milestones = [int(x) for x in args.decay.split('-')]
    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
    scheduler_class = lrs.MultiStepLR

    class CustomOptimizer(optimizer_class):
        '''Selected optimizer bundled with its scheduler and checkpoint I/O.'''

        def __init__(self, *args, **kwargs):
            super(CustomOptimizer, self).__init__(*args, **kwargs)

        def _register_scheduler(self, scheduler_class, **kwargs):
            # The scheduler needs the constructed optimizer, so it is
            # attached after __init__ rather than inside it.
            self.scheduler = scheduler_class(self, **kwargs)

        def save(self, save_dir):
            '''Write the optimizer state to ``<save_dir>/optimizer.pt``.'''
            torch.save(self.state_dict(), self.get_dir(save_dir))

        def load(self, load_dir, epoch=1):
            '''Restore the optimizer state and fast-forward the scheduler.'''
            # torch.load unpickles the file: only load trusted checkpoints.
            self.load_state_dict(torch.load(self.get_dir(load_dir)))
            # NOTE(review): with epoch == 1 the scheduler is not advanced
            # at all, while epoch == 2 advances it twice - confirm that
            # this asymmetry is intended.
            if epoch > 1:
                for _ in range(epoch):
                    self.scheduler.step()

        def get_dir(self, dir_path):
            '''Return the checkpoint file path inside ``dir_path``.'''
            return os.path.join(dir_path, 'optimizer.pt')

        def schedule(self):
            '''Advance the scheduler by one epoch.'''
            self.scheduler.step()

        def get_lr(self):
            '''Return the current learning rate of the first param group.'''
            # Newer PyTorch only guarantees get_lr() inside step(); called
            # from outside it can apply gamma a second time on a milestone
            # epoch. Prefer get_last_lr() and keep get_lr() as a fallback
            # for releases that predate it.
            get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
            if get_last_lr is not None:
                return get_last_lr()[0]
            return self.scheduler.get_lr()[0]

        def get_last_epoch(self):
            '''Return the scheduler's epoch counter.'''
            return self.scheduler.last_epoch

    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
    return optimizer
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment