From e2f7393e4d63fdae5fde27dc9a35f8b6b4f20ac7 Mon Sep 17 00:00:00 2001
From: Sanghyun Son <thstkdgus35@snu.ac.kr>
Date: Thu, 18 Oct 2018 11:14:58 +0900
Subject: [PATCH] minor style change

---
 src/data/srdata.py      |  4 ++--
 src/loss/__init__.py    |  2 +-
 src/loss/adversarial.py |  1 -
 src/loss/vgg.py         |  1 -
 src/model/__init__.py   |  1 -
 src/model/common.py     |  2 --
 src/option.py           |  8 +++-----
 src/trainer.py          |  2 +-
 src/utility.py          | 46 ++++++++++++++++++----------------------------
 9 files changed, 25 insertions(+), 42 deletions(-)

diff --git a/src/data/srdata.py b/src/data/srdata.py
index 97723cf..a7c9a94 100644
--- a/src/data/srdata.py
+++ b/src/data/srdata.py
@@ -174,8 +174,8 @@ class SRData(data.Dataset):
             hr = imageio.imread(f_hr)
             lr = imageio.imread(f_lr)
         elif self.args.ext.find('sep') >= 0:
-            with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image']
-            with open(f_lr, 'rb') as _f: lr = np.load(_f)[0]['image']
+            with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image']
+            with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image']
 
         return lr, hr, filename
 
diff --git a/src/loss/__init__.py b/src/loss/__init__.py
index 27c2e6b..6d7c21e 100644
--- a/src/loss/__init__.py
+++ b/src/loss/__init__.py
@@ -64,7 +64,7 @@ class Loss(nn.modules.loss._Loss):
                     self.loss_module, range(args.n_GPUs)
                 )
 
-        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
+        if args.load != '': self.load(ckp.dir, cpu=args.cpu)
 
     def forward(self, sr, hr):
         losses = []
diff --git a/src/loss/adversarial.py b/src/loss/adversarial.py
index 57275df..c4b7a4a 100644
--- a/src/loss/adversarial.py
+++ b/src/loss/adversarial.py
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable
 
 class Adversarial(nn.Module):
     def __init__(self, args, gan_type):
diff --git a/src/loss/vgg.py b/src/loss/vgg.py
index 78a8c3b..a0167f5 100644
--- a/src/loss/vgg.py
+++ b/src/loss/vgg.py
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
-from torch.autograd import Variable
 
 class VGG(nn.Module):
     def __init__(self, conv_index, rgb_range=1):
diff --git a/src/model/__init__.py b/src/model/__init__.py
index cdb6fd8..a2cc30d 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -4,7 +4,6 @@ from importlib import import_module
 import torch
 import torch.nn as nn
 import torch.utils.model_zoo
-from torch.autograd import Variable
 
 class Model(nn.Module):
     def __init__(self, args, ckp):
diff --git a/src/model/common.py b/src/model/common.py
index 79d0a0e..74ffa37 100644
--- a/src/model/common.py
+++ b/src/model/common.py
@@ -4,8 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from torch.autograd import Variable
-
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
     return nn.Conv2d(
         in_channels, out_channels, kernel_size,
diff --git a/src/option.py b/src/option.py
index afe2227..6729330 100644
--- a/src/option.py
+++ b/src/option.py
@@ -114,10 +114,8 @@ parser.add_argument('--optimizer', default='ADAM',
                     help='optimizer to use (SGD | ADAM | RMSprop)')
 parser.add_argument('--momentum', type=float, default=0.9,
                     help='SGD momentum')
-parser.add_argument('--beta1', type=float, default=0.9,
-                    help='ADAM beta1')
-parser.add_argument('--beta2', type=float, default=0.999,
-                    help='ADAM beta2')
+parser.add_argument('--beta', type=tuple, default=(0.9, 0.999),
+                    help='ADAM beta')
 parser.add_argument('--epsilon', type=float, default=1e-8,
                     help='ADAM epsilon for numerical stability')
 parser.add_argument('--weight_decay', type=float, default=0,
@@ -134,7 +132,7 @@ parser.add_argument('--skip_threshold', type=float, default='1e8',
 # Log specifications
 parser.add_argument('--save', type=str, default='test',
                     help='file name to save')
-parser.add_argument('--load', type=str, default='.',
+parser.add_argument('--load', type=str, default='',
                     help='file name to load')
 parser.add_argument('--resume', type=int, default=0,
                     help='resume from specific checkpoint')
diff --git a/src/trainer.py b/src/trainer.py
index 77ae2de..a05fa0a 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -21,7 +21,7 @@ class Trainer():
         self.optimizer = utility.make_optimizer(args, self.model)
         self.scheduler = utility.make_scheduler(args, self.optimizer)
 
-        if self.args.load != '.':
+        if self.args.load != '':
             self.optimizer.load_state_dict(
                 torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
             )
diff --git a/src/utility.py b/src/utility.py
index 866c737..25aabd9 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -48,20 +48,21 @@ class checkpoint():
         self.log = torch.Tensor()
         now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
 
-        if args.load == '.':
-            if args.save == '.': args.save = now
+        if not args.load:
+            if not args.save:
+                args.save = now
             self.dir = os.path.join('..', 'experiment', args.save)
         else:
             self.dir = os.path.join('..', 'experiment', args.load)
-            if not os.path.exists(self.dir):
-                args.load = '.'
-            else:
+            if os.path.exists(self.dir):
                 self.log = torch.load(self.get_path('psnr_log.pt'))
                 print('Continue from epoch {}...'.format(len(self.log)))
+            else:
+                args.load = ''
 
         if args.reset:
             os.system('rm -rf ' + self.dir)
-            args.load = '.'
+            args.load = ''
 
         os.makedirs(self.dir, exist_ok=True)
         os.makedirs(self.get_path('model'), exist_ok=True)
@@ -171,16 +172,13 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
     if dataset and dataset.dataset.benchmark:
         shave = scale
         if diff.size(1) > 1:
-            convert = diff.new(1, 3, 1, 1)
-            convert[0, 0, 0, 0] = 65.738
-            convert[0, 1, 0, 0] = 129.057
-            convert[0, 2, 0, 0] = 25.064
-            diff *= (convert / 256)
-            diff = diff.sum(dim=1, keepdim=True)
+            gray_coeffs = [65.738, 129.057, 25.064]
+            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
+            diff = diff.mul(convert).sum(dim=1)
     else:
         shave = scale + 6
 
-    valid = diff[:, :, shave:-shave, shave:-shave]
+    valid = diff[..., shave:-shave, shave:-shave]
     mse = valid.pow(2).mean()
 
     return -10 * math.log10(mse)
@@ -194,7 +192,7 @@
     elif args.optimizer == 'ADAM':
         optimizer_function = optim.Adam
         kwargs = {
-            'betas': (args.beta1, args.beta2),
+            'betas': args.beta,
             'eps': args.epsilon
         }
     elif args.optimizer == 'RMSprop':
@@ -208,20 +206,12 @@ def make_scheduler(args, my_optimizer):
     if args.decay_type == 'step':
-        scheduler = lrs.StepLR(
-            my_optimizer,
-            step_size=args.lr_decay,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.StepLR
+        kwargs = {'step_size': args.lr_decay, 'gamma': args.gamma}
     elif args.decay_type.find('step') >= 0:
-        milestones = args.decay_type.split('_')
-        milestones.pop(0)
-        milestones = list(map(lambda x: int(x), milestones))
-        scheduler = lrs.MultiStepLR(
-            my_optimizer,
-            milestones=milestones,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.MultiStepLR
+        milestones = list(map(lambda x: int(x), args.decay_type.split('-')[1:]))
+        kwargs = {'milestones': milestones, 'gamma': args.gamma}
 
-    return scheduler
+    return scheduler_function(my_optimizer, **kwargs)
-- 
GitLab
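
Note on the srdata.py hunk: the pre-decoded 'sep' binaries are now read back with pickle instead of np.load. Below is a minimal round-trip sketch (not part of the patch) of the on-disk format the new reader assumes, namely a pickled list whose first entry is a dict carrying the decoded array; the helper names and paths are illustrative only:

    import pickle
    import imageio

    def save_binary(img_path, bin_path):
        # Decode the image once and cache the raw ndarray, so later
        # epochs can unpickle it instead of re-decoding the file.
        with open(bin_path, 'wb') as _f:
            pickle.dump([{'image': imageio.imread(img_path)}], _f)

    def load_binary(bin_path):
        # Mirrors the patched 'sep' branch: unpickle, take the first
        # record, and return its stored array.
        with open(bin_path, 'rb') as _f:
            return pickle.load(_f)[0]['image']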
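
Note on the calc_psnr hunk: the element-wise writes into diff.new(1, 3, 1, 1) are collapsed into one new_tensor(...).view(1, 3, 1, 1) broadcast, and sum(dim=1) now drops the channel axis (no keepdim=True), which is why the shave indexing changes from diff[:, :, ...] to diff[..., ...]; the ellipsis form indexes the last two axes whether or not the channel axis survives. A standalone sketch with an illustrative random batch:

    import torch

    diff = torch.randn(2, 3, 8, 8)  # stand-in for (sr - hr) / rgb_range
    # BT.601 luma weights, divided by 256 as in the patch; new_tensor()
    # keeps the dtype and device of `diff`.
    gray_coeffs = [65.738, 129.057, 25.064]
    convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
    y = diff.mul(convert).sum(dim=1)  # shape (2, 8, 8): channel axis dropped
    valid = y[..., 2:-2, 2:-2]        # '...' spans the leading axes, 3-D or 4-D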
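
Note on the make_scheduler hunk: both factories now share the same select-then-instantiate shape, where each branch only picks a class and assembles a kwargs dict, and construction happens once at the end. A self-contained sketch under the patch's assumption that milestone strings are dash-separated, as in the new split('-') call (the function name and defaults here are illustrative, not the repository's):

    import torch
    import torch.optim as optim
    import torch.optim.lr_scheduler as lrs

    def make_scheduler_sketch(optimizer, decay_type='step-200-400',
                              gamma=0.5, lr_decay=200):
        if decay_type == 'step':
            scheduler_function = lrs.StepLR
            kwargs = {'step_size': lr_decay, 'gamma': gamma}
        elif decay_type.find('step') >= 0:
            # 'step-200-400' -> milestones [200, 400]
            milestones = [int(x) for x in decay_type.split('-')[1:]]
            scheduler_function = lrs.MultiStepLR
            kwargs = {'milestones': milestones, 'gamma': gamma}

        # A single construction site instead of one per branch.
        return scheduler_function(optimizer, **kwargs)

    # Usage with a throwaway optimizer:
    params = [torch.nn.Parameter(torch.zeros(1))]
    scheduler = make_scheduler_sketch(optim.SGD(params, lr=0.1))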