From e2f7393e4d63fdae5fde27dc9a35f8b6b4f20ac7 Mon Sep 17 00:00:00 2001
From: Sanghyun Son <thstkdgus35@snu.ac.kr>
Date: Thu, 18 Oct 2018 11:14:58 +0900
Subject: [PATCH] minor style change

---
 src/data/srdata.py      |  4 ++--
 src/loss/__init__.py    |  2 +-
 src/loss/adversarial.py |  1 -
 src/loss/vgg.py         |  1 -
 src/model/__init__.py   |  1 -
 src/model/common.py     |  2 --
 src/option.py           |  8 +++----
 src/trainer.py          |  2 +-
 src/utility.py          | 46 ++++++++++++++++-------------------------
 9 files changed, 25 insertions(+), 42 deletions(-)

diff --git a/src/data/srdata.py b/src/data/srdata.py
index 97723cf..a7c9a94 100644
--- a/src/data/srdata.py
+++ b/src/data/srdata.py
@@ -174,8 +174,8 @@ class SRData(data.Dataset):
                 hr = imageio.imread(f_hr)
                 lr = imageio.imread(f_lr)
             elif self.args.ext.find('sep') >= 0:
-                with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image']
-                with open(f_lr, 'rb') as _f: lr = np.load(_f)[0]['image']
+                with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image']
+                with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image']
 
         return lr, hr, filename
 
diff --git a/src/loss/__init__.py b/src/loss/__init__.py
index 27c2e6b..6d7c21e 100644
--- a/src/loss/__init__.py
+++ b/src/loss/__init__.py
@@ -64,7 +64,7 @@ class Loss(nn.modules.loss._Loss):
                 self.loss_module, range(args.n_GPUs)
             )
 
-        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
+        if args.load != '': self.load(ckp.dir, cpu=args.cpu)
 
     def forward(self, sr, hr):
         losses = []
diff --git a/src/loss/adversarial.py b/src/loss/adversarial.py
index 57275df..c4b7a4a 100644
--- a/src/loss/adversarial.py
+++ b/src/loss/adversarial.py
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable
 
 class Adversarial(nn.Module):
     def __init__(self, args, gan_type):
diff --git a/src/loss/vgg.py b/src/loss/vgg.py
index 78a8c3b..a0167f5 100644
--- a/src/loss/vgg.py
+++ b/src/loss/vgg.py
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
-from torch.autograd import Variable
 
 class VGG(nn.Module):
     def __init__(self, conv_index, rgb_range=1):
diff --git a/src/model/__init__.py b/src/model/__init__.py
index cdb6fd8..a2cc30d 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -4,7 +4,6 @@ from importlib import import_module
 import torch
 import torch.nn as nn
 import torch.utils.model_zoo
-from torch.autograd import Variable
 
 class Model(nn.Module):
     def __init__(self, args, ckp):
diff --git a/src/model/common.py b/src/model/common.py
index 79d0a0e..74ffa37 100644
--- a/src/model/common.py
+++ b/src/model/common.py
@@ -4,8 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from torch.autograd import Variable
-
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
     return nn.Conv2d(
         in_channels, out_channels, kernel_size,
diff --git a/src/option.py b/src/option.py
index afe2227..6729330 100644
--- a/src/option.py
+++ b/src/option.py
@@ -114,10 +114,8 @@ parser.add_argument('--optimizer', default='ADAM',
                     help='optimizer to use (SGD | ADAM | RMSprop)')
 parser.add_argument('--momentum', type=float, default=0.9,
                     help='SGD momentum')
-parser.add_argument('--beta1', type=float, default=0.9,
-                    help='ADAM beta1')
-parser.add_argument('--beta2', type=float, default=0.999,
-                    help='ADAM beta2')
+parser.add_argument('--beta', type=tuple, default=(0.9, 0.999),
+                    help='ADAM beta')
 parser.add_argument('--epsilon', type=float, default=1e-8,
                     help='ADAM epsilon for numerical stability')
 parser.add_argument('--weight_decay', type=float, default=0,
@@ -134,7 +132,7 @@ parser.add_argument('--skip_threshold', type=float, default='1e8',
 # Log specifications
 parser.add_argument('--save', type=str, default='test',
                     help='file name to save')
-parser.add_argument('--load', type=str, default='.',
+parser.add_argument('--load', type=str, default='',
                     help='file name to load')
 parser.add_argument('--resume', type=int, default=0,
                     help='resume from specific checkpoint')
diff --git a/src/trainer.py b/src/trainer.py
index 77ae2de..a05fa0a 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -21,7 +21,7 @@ class Trainer():
         self.optimizer = utility.make_optimizer(args, self.model)
         self.scheduler = utility.make_scheduler(args, self.optimizer)
 
-        if self.args.load != '.':
+        if self.args.load != '':
             self.optimizer.load_state_dict(
                 torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
             )
diff --git a/src/utility.py b/src/utility.py
index 866c737..25aabd9 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -48,20 +48,21 @@ class checkpoint():
         self.log = torch.Tensor()
         now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
 
-        if args.load == '.':
-            if args.save == '.': args.save = now
+        if not args.load:
+            if not args.save:
+                args.save = now
             self.dir = os.path.join('..', 'experiment', args.save)
         else:
             self.dir = os.path.join('..', 'experiment', args.load)
-            if not os.path.exists(self.dir):
-                args.load = '.'
-            else:
+            if os.path.exists(self.dir):
                 self.log = torch.load(self.get_path('psnr_log.pt'))
                 print('Continue from epoch {}...'.format(len(self.log)))
+            else:
+                args.load = ''
 
         if args.reset:
             os.system('rm -rf ' + self.dir)
-            args.load = '.'
+            args.load = ''
 
         os.makedirs(self.dir, exist_ok=True)
         os.makedirs(self.get_path('model'), exist_ok=True)
@@ -171,16 +172,13 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
     if dataset and dataset.dataset.benchmark:
         shave = scale
         if diff.size(1) > 1:
-            convert = diff.new(1, 3, 1, 1)
-            convert[0, 0, 0, 0] = 65.738
-            convert[0, 1, 0, 0] = 129.057
-            convert[0, 2, 0, 0] = 25.064
-            diff *= (convert / 256)
-            diff = diff.sum(dim=1, keepdim=True)
+            gray_coeffs = [65.738, 129.057, 25.064]
+            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
+            diff = diff.mul(convert).sum(dim=1)
     else:
         shave = scale + 6
 
-    valid = diff[:, :, shave:-shave, shave:-shave]
+    valid = diff[..., shave:-shave, shave:-shave]
     mse = valid.pow(2).mean()
 
     return -10 * math.log10(mse)
@@ -194,7 +192,7 @@ def make_optimizer(args, my_model):
     elif args.optimizer == 'ADAM':
         optimizer_function = optim.Adam
         kwargs = {
-            'betas': (args.beta1, args.beta2),
+            'betas': args.beta,
             'eps': args.epsilon
         }
     elif args.optimizer == 'RMSprop':
@@ -208,20 +206,12 @@ def make_optimizer(args, my_model):
 
 def make_scheduler(args, my_optimizer):
     if args.decay_type == 'step':
-        scheduler = lrs.StepLR(
-            my_optimizer,
-            step_size=args.lr_decay,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.StepLR
+        kwargs = {'step_size': args.lr_decay, 'gamma': args.gamma}
     elif args.decay_type.find('step') >= 0:
-        milestones = args.decay_type.split('_')
-        milestones.pop(0)
-        milestones = list(map(lambda x: int(x), milestones))
-        scheduler = lrs.MultiStepLR(
-            my_optimizer,
-            milestones=milestones,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.MultiStepLR
+        milestones = list(map(lambda x: int(x), args.decay_type.split('-')[1:]))
+        kwarg = {'milestones': milestones, 'gamma': args.gamma}
 
-    return scheduler
+    return scheduler_function(my_optimizer, **kwargs)
 
-- 
GitLab