diff --git a/src/data/srdata.py b/src/data/srdata.py
index 97723cfd37161e3e006a58127b2654b21864b11a..a7c9a947fdf492f640fc329980ff43c4efcdb8c0 100644
--- a/src/data/srdata.py
+++ b/src/data/srdata.py
@@ -174,8 +174,8 @@ class SRData(data.Dataset):
             hr = imageio.imread(f_hr)
             lr = imageio.imread(f_lr)
         elif self.args.ext.find('sep') >= 0:
-            with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image']
-            with open(f_lr, 'rb') as _f: lr = np.load(_f)[0]['image']
+            with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image']
+            with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image']
 
         return lr, hr, filename
 
diff --git a/src/loss/__init__.py b/src/loss/__init__.py
index 27c2e6b828db5e97bad29790bb3416204b5dc9fa..6d7c21e2f9dbf94a20fb6c74ede439711c44a15c 100644
--- a/src/loss/__init__.py
+++ b/src/loss/__init__.py
@@ -64,7 +64,7 @@ class Loss(nn.modules.loss._Loss):
                 self.loss_module, range(args.n_GPUs)
             )
 
-        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
+        if args.load != '': self.load(ckp.dir, cpu=args.cpu)
 
     def forward(self, sr, hr):
         losses = []
diff --git a/src/loss/adversarial.py b/src/loss/adversarial.py
index 57275df7532e650c423f24c8b3c7a86e07f41f3c..c4b7a4a90e234107130fe2f062a4c1b9f3429216 100644
--- a/src/loss/adversarial.py
+++ b/src/loss/adversarial.py
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable
 
 class Adversarial(nn.Module):
     def __init__(self, args, gan_type):
diff --git a/src/loss/vgg.py b/src/loss/vgg.py
index 78a8c3bac702673398cd6bdc1a66ca75caf5b099..a0167f5c8598abf1ad99ec81f32a6e13fefacbc1 100644
--- a/src/loss/vgg.py
+++ b/src/loss/vgg.py
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
-from torch.autograd import Variable
 
 class VGG(nn.Module):
     def __init__(self, conv_index, rgb_range=1):
diff --git a/src/model/__init__.py b/src/model/__init__.py
index cdb6fd8bdface47ad2107453ccdb9f2d0c5b4477..a2cc30d63fd417b965be92135e97cdfe7ee6cda8 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -4,7 +4,6 @@ from importlib import import_module
 import torch
 import torch.nn as nn
 import torch.utils.model_zoo
-from torch.autograd import Variable
 
 class Model(nn.Module):
     def __init__(self, args, ckp):
diff --git a/src/model/common.py b/src/model/common.py
index 79d0a0ec35b836cd7a0c81428af0e111998ce83f..74ffa371a680e80bcea79853f4d948c433c72721 100644
--- a/src/model/common.py
+++ b/src/model/common.py
@@ -4,8 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from torch.autograd import Variable
-
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
     return nn.Conv2d(
         in_channels, out_channels, kernel_size,
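Note on the hunks above: the four "from torch.autograd import Variable" imports are safe to drop because PyTorch 0.4 merged Variable into Tensor, so tensors carry autograd state themselves. The srdata.py change additionally assumes the 'sep' binary caches are pickled lists holding one dict with an 'image' key, and that srdata.py already imports pickle at module level. A minimal sketch of the matching writer, with a hypothetical helper name:

    import pickle

    import imageio

    def save_sep_cache(img_path, bin_path):
        # the reader above does pickle.load(f)[0]['image'], so write a
        # one-element list of dicts (helper and 'name' key are illustrative)
        img = imageio.imread(img_path)
        with open(bin_path, 'wb') as f:
            pickle.dump([{'name': img_path, 'image': img}], f)
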
diff --git a/src/option.py b/src/option.py
index afe22277d7d1bf2a3f550c3ff5a061d2d795de6e..672933096c14ddb4bf3b05c14382f0d4cdd206e5 100644
--- a/src/option.py
+++ b/src/option.py
@@ -114,10 +114,8 @@ parser.add_argument('--optimizer', default='ADAM',
                     help='optimizer to use (SGD | ADAM | RMSprop)')
 parser.add_argument('--momentum', type=float, default=0.9,
                     help='SGD momentum')
-parser.add_argument('--beta1', type=float, default=0.9,
-                    help='ADAM beta1')
-parser.add_argument('--beta2', type=float, default=0.999,
-                    help='ADAM beta2')
+parser.add_argument('--beta', type=tuple, default=(0.9, 0.999),
+                    help='ADAM beta')
 parser.add_argument('--epsilon', type=float, default=1e-8,
                     help='ADAM epsilon for numerical stability')
 parser.add_argument('--weight_decay', type=float, default=0,
@@ -134,7 +132,7 @@ parser.add_argument('--skip_threshold', type=float, default='1e8',
 # Log specifications
 parser.add_argument('--save', type=str, default='test',
                     help='file name to save')
-parser.add_argument('--load', type=str, default='.',
+parser.add_argument('--load', type=str, default='',
                     help='file name to load')
 parser.add_argument('--resume', type=int, default=0,
                     help='resume from specific checkpoint')
diff --git a/src/trainer.py b/src/trainer.py
index 77ae2de0c9d8c965ffdd16554a9957dcdc24774c..a05fa0a515b2c309e0a13a9606db54c5e9fd7527 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -21,7 +21,7 @@ class Trainer():
         self.optimizer = utility.make_optimizer(args, self.model)
         self.scheduler = utility.make_scheduler(args, self.optimizer)
 
-        if self.args.load != '.':
+        if self.args.load != '':
             self.optimizer.load_state_dict(
                 torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
             )
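Two notes on option.py. Replacing the '.' sentinel for --load with the empty string lets callers test the flag with plain truthiness, which is exactly what the utility.py hunk below does with "if not args.load:". The new --beta flag, however, only behaves through its default: argparse applies the type callable to the raw command-line string, so type=tuple would turn '0.9,0.999' into a tuple of single characters. A converter along these lines (a sketch; float_pair is a hypothetical helper, not part of this patch) would make the flag usable from the command line:

    def float_pair(s):
        # argparse hands over the raw string: '0.9,0.999' -> (0.9, 0.999)
        return tuple(map(float, s.split(',')))

    parser.add_argument('--beta', type=float_pair, default=(0.9, 0.999),
                        help='ADAM beta')
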
diff --git a/src/utility.py b/src/utility.py
index 866c737b376c22a7c76774b9e356c87fce4e86fd..25aabd9f554fce1cbd6d830d8985ed9c6dad5276 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -48,20 +48,21 @@ class checkpoint():
         self.log = torch.Tensor()
 
         now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
-        if args.load == '.':
-            if args.save == '.': args.save = now
+        if not args.load:
+            if not args.save:
+                args.save = now
             self.dir = os.path.join('..', 'experiment', args.save)
         else:
             self.dir = os.path.join('..', 'experiment', args.load)
-            if not os.path.exists(self.dir):
-                args.load = '.'
-            else:
+            if os.path.exists(self.dir):
                 self.log = torch.load(self.get_path('psnr_log.pt'))
                 print('Continue from epoch {}...'.format(len(self.log)))
+            else:
+                args.load = ''
 
         if args.reset:
             os.system('rm -rf ' + self.dir)
-            args.load = '.'
+            args.load = ''
 
         os.makedirs(self.dir, exist_ok=True)
         os.makedirs(self.get_path('model'), exist_ok=True)
@@ -171,16 +172,13 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
     if dataset and dataset.dataset.benchmark:
         shave = scale
         if diff.size(1) > 1:
-            convert = diff.new(1, 3, 1, 1)
-            convert[0, 0, 0, 0] = 65.738
-            convert[0, 1, 0, 0] = 129.057
-            convert[0, 2, 0, 0] = 25.064
-            diff *= (convert / 256)
-            diff = diff.sum(dim=1, keepdim=True)
+            gray_coeffs = [65.738, 129.057, 25.064]
+            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
+            diff = diff.mul(convert).sum(dim=1)
     else:
         shave = scale + 6
 
-    valid = diff[:, :, shave:-shave, shave:-shave]
+    valid = diff[..., shave:-shave, shave:-shave]
     mse = valid.pow(2).mean()
 
     return -10 * math.log10(mse)
@@ -194,7 +192,7 @@ def make_optimizer(args, my_model):
     elif args.optimizer == 'ADAM':
         optimizer_function = optim.Adam
         kwargs = {
-            'betas': (args.beta1, args.beta2),
+            'betas': args.beta,
             'eps': args.epsilon
         }
     elif args.optimizer == 'RMSprop':
@@ -208,20 +206,12 @@ def make_scheduler(args, my_optimizer):
     if args.decay_type == 'step':
-        scheduler = lrs.StepLR(
-            my_optimizer,
-            step_size=args.lr_decay,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.StepLR
+        kwargs = {'step_size': args.lr_decay, 'gamma': args.gamma}
     elif args.decay_type.find('step') >= 0:
-        milestones = args.decay_type.split('_')
-        milestones.pop(0)
-        milestones = list(map(lambda x: int(x), milestones))
-        scheduler = lrs.MultiStepLR(
-            my_optimizer,
-            milestones=milestones,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.MultiStepLR
+        milestones = list(map(lambda x: int(x), args.decay_type.split('-')[1:]))
+        kwargs = {'milestones': milestones, 'gamma': args.gamma}
 
-    return scheduler
+    return scheduler_function(my_optimizer, **kwargs)
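
In calc_psnr, the element-by-element construction of the luminance weights collapses into one new_tensor call; the coefficients are the BT.601-derived RGB-to-Y weights (scaled by 256) used by the common SR evaluation protocol. Since sum(dim=1) now drops the channel axis, the crop switches from [:, :, ...] to [..., ...], which works for both 3-D and 4-D tensors. A quick self-contained check that the vectorized form reproduces the old values (shapes are illustrative):

    import torch

    diff = torch.randn(1, 3, 8, 8)

    # old: fill a 1x3x1x1 tensor entry by entry, keep the channel axis
    convert_old = diff.new(1, 3, 1, 1)
    convert_old[0, 0, 0, 0] = 65.738
    convert_old[0, 1, 0, 0] = 129.057
    convert_old[0, 2, 0, 0] = 25.064
    old = (diff * (convert_old / 256)).sum(dim=1, keepdim=True)

    # new: build the weights in one shot, drop the channel axis
    convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1) / 256
    new = diff.mul(convert).sum(dim=1)

    print(torch.allclose(old.squeeze(1), new))  # True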