Commit e2f7393e authored by Sanghyun Son

minor style change

parent ea971d90
Branch: master
1 merge request: !1 Jan 09, 2018 updates
@@ -174,8 +174,8 @@ class SRData(data.Dataset):
             hr = imageio.imread(f_hr)
             lr = imageio.imread(f_lr)
         elif self.args.ext.find('sep') >= 0:
-            with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image']
-            with open(f_lr, 'rb') as _f: lr = np.load(_f)[0]['image']
+            with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image']
+            with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image']
 
         return lr, hr, filename
...
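The 'sep' branch above reads each binary as a pickled list whose first element is a dict with an 'image' entry. A minimal sketch of a matching writer, with hypothetical file names (the list-of-dict layout is inferred from pickle.load(_f)[0]['image'] in the diff):

    import pickle
    import imageio

    # Hypothetical paths; the layout mirrors what the loader above expects.
    img = imageio.imread('0001.png')
    with open('0001.pt', 'wb') as _f:
        pickle.dump([{'image': img}], _f)

    with open('0001.pt', 'rb') as _f:
        hr = pickle.load(_f)[0]['image']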
@@ -64,7 +64,7 @@ class Loss(nn.modules.loss._Loss):
                 self.loss_module, range(args.n_GPUs)
             )
 
-        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
+        if args.load != '': self.load(ckp.dir, cpu=args.cpu)
 
     def forward(self, sr, hr):
         losses = []
...
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable
 
 class Adversarial(nn.Module):
     def __init__(self, args, gan_type):
...
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
-from torch.autograd import Variable
 
 class VGG(nn.Module):
     def __init__(self, conv_index, rgb_range=1):
...
@@ -4,7 +4,6 @@ from importlib import import_module
 import torch
 import torch.nn as nn
 import torch.utils.model_zoo
-from torch.autograd import Variable
 
 class Model(nn.Module):
     def __init__(self, args, ckp):
...
@@ -4,8 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from torch.autograd import Variable
-
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
     return nn.Conv2d(
         in_channels, out_channels, kernel_size,
...
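Dropping from torch.autograd import Variable in the four hunks above is consistent with newer PyTorch releases, where Variable has been folded into Tensor and plain tensors participate in autograd directly, so the wrapper is no longer needed. A minimal illustration, assuming PyTorch 0.4 or later:

    import torch

    # No Variable wrapper needed: tensors track gradients themselves.
    x = torch.ones(2, 2, requires_grad=True)
    y = (x * 3).sum()
    y.backward()
    print(x.grad)  # tensor filled with 3.0, same shape as x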
@@ -114,10 +114,8 @@ parser.add_argument('--optimizer', default='ADAM',
                     help='optimizer to use (SGD | ADAM | RMSprop)')
 parser.add_argument('--momentum', type=float, default=0.9,
                     help='SGD momentum')
-parser.add_argument('--beta1', type=float, default=0.9,
-                    help='ADAM beta1')
-parser.add_argument('--beta2', type=float, default=0.999,
-                    help='ADAM beta2')
+parser.add_argument('--beta', type=tuple, default=(0.9, 0.999),
+                    help='ADAM beta')
 parser.add_argument('--epsilon', type=float, default=1e-8,
                     help='ADAM epsilon for numerical stability')
 parser.add_argument('--weight_decay', type=float, default=0,
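One caveat with the new --beta flag: argparse applies type=tuple to the raw string, so the (0.9, 0.999) default works, but a command-line override is split into individual characters. A small standalone check of this standard argparse behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--beta', type=tuple, default=(0.9, 0.999))

    print(parser.parse_args([]).beta)                     # (0.9, 0.999)
    print(parser.parse_args(['--beta', '0.5,0.9']).beta)  # ('0', '.', '5', ',', '0', '.', '9')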
@@ -134,7 +132,7 @@ parser.add_argument('--skip_threshold', type=float, default='1e8',
 # Log specifications
 parser.add_argument('--save', type=str, default='test',
                     help='file name to save')
-parser.add_argument('--load', type=str, default='.',
+parser.add_argument('--load', type=str, default='',
                     help='file name to load')
 parser.add_argument('--resume', type=int, default=0,
                     help='resume from specific checkpoint')
...
@@ -21,7 +21,7 @@ class Trainer():
         self.optimizer = utility.make_optimizer(args, self.model)
         self.scheduler = utility.make_scheduler(args, self.optimizer)
 
-        if self.args.load != '.':
+        if self.args.load != '':
             self.optimizer.load_state_dict(
                 torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
             )
...
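The load path above assumes an optimizer.pt that was written earlier with torch.save(optimizer.state_dict(), ...). A minimal sketch of the matching save and restore, with a hypothetical model and checkpoint directory:

    import os
    import torch
    import torch.optim as optim

    model = torch.nn.Linear(4, 4)                        # hypothetical model
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    ckp_dir = '../experiment/test'                       # hypothetical directory
    os.makedirs(ckp_dir, exist_ok=True)
    torch.save(optimizer.state_dict(), os.path.join(ckp_dir, 'optimizer.pt'))

    # Restore, as in the diff above:
    optimizer.load_state_dict(torch.load(os.path.join(ckp_dir, 'optimizer.pt')))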
@@ -48,20 +48,21 @@ class checkpoint():
         self.log = torch.Tensor()
         now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
-        if args.load == '.':
-            if args.save == '.': args.save = now
+        if not args.load:
+            if not args.save:
+                args.save = now
             self.dir = os.path.join('..', 'experiment', args.save)
         else:
             self.dir = os.path.join('..', 'experiment', args.load)
-            if not os.path.exists(self.dir):
-                args.load = '.'
-            else:
+            if os.path.exists(self.dir):
                 self.log = torch.load(self.get_path('psnr_log.pt'))
                 print('Continue from epoch {}...'.format(len(self.log)))
+            else:
+                args.load = ''
 
         if args.reset:
             os.system('rm -rf ' + self.dir)
-            args.load = '.'
+            args.load = ''
 
         os.makedirs(self.dir, exist_ok=True)
         os.makedirs(self.get_path('model'), exist_ok=True)
@@ -171,16 +172,13 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
     if dataset and dataset.dataset.benchmark:
         shave = scale
         if diff.size(1) > 1:
-            convert = diff.new(1, 3, 1, 1)
-            convert[0, 0, 0, 0] = 65.738
-            convert[0, 1, 0, 0] = 129.057
-            convert[0, 2, 0, 0] = 25.064
-            diff *= (convert / 256)
-            diff = diff.sum(dim=1, keepdim=True)
+            gray_coeffs = [65.738, 129.057, 25.064]
+            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
+            diff = diff.mul(convert).sum(dim=1)
     else:
         shave = scale + 6
 
-    valid = diff[:, :, shave:-shave, shave:-shave]
+    valid = diff[..., shave:-shave, shave:-shave]
     mse = valid.pow(2).mean()
 
     return -10 * math.log10(mse)
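The rewritten benchmark branch collapses RGB to a BT.601-style luma channel, Y = (65.738 R + 129.057 G + 25.064 B) / 256, and then computes PSNR over the shaved region. A standalone sketch with random tensors standing in for real SR/HR images:

    import math
    import torch

    sr = torch.rand(1, 3, 48, 48) * 255   # hypothetical super-resolved output
    hr = torch.rand(1, 3, 48, 48) * 255   # hypothetical ground truth
    rgb_range, scale = 255, 4

    diff = (sr - hr) / rgb_range
    # Benchmark path: collapse RGB to a single luma channel.
    gray_coeffs = [65.738, 129.057, 25.064]
    convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
    diff = diff.mul(convert).sum(dim=1)

    shave = scale
    valid = diff[..., shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()
    psnr = -10 * math.log10(mse)
    print(psnr)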
@@ -194,7 +192,7 @@ def make_optimizer(args, my_model):
     elif args.optimizer == 'ADAM':
         optimizer_function = optim.Adam
         kwargs = {
-            'betas': (args.beta1, args.beta2),
+            'betas': args.beta,
             'eps': args.epsilon
         }
     elif args.optimizer == 'RMSprop':
@@ -208,20 +206,12 @@ def make_optimizer(args, my_model):
 
 def make_scheduler(args, my_optimizer):
     if args.decay_type == 'step':
-        scheduler = lrs.StepLR(
-            my_optimizer,
-            step_size=args.lr_decay,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.StepLR
+        kwargs = {'step_size': args.lr_decay, 'gamma': args.gamma}
     elif args.decay_type.find('step') >= 0:
-        milestones = args.decay_type.split('_')
-        milestones.pop(0)
-        milestones = list(map(lambda x: int(x), milestones))
-        scheduler = lrs.MultiStepLR(
-            my_optimizer,
-            milestones=milestones,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.MultiStepLR
+        milestones = list(map(lambda x: int(x), args.decay_type.split('-')[1:]))
+        kwargs = {'milestones': milestones, 'gamma': args.gamma}
 
-    return scheduler
+    return scheduler_function(my_optimizer, **kwargs)
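With the refactored factory, a decay_type such as 'step-200-400-600' should map to MultiStepLR milestones [200, 400, 600] (the flag format is inferred from split('-')[1:] above). A short standalone check with a hypothetical optimizer and values:

    import torch
    import torch.optim as optim
    import torch.optim.lr_scheduler as lrs

    decay_type, gamma, lr_decay = 'step-200-400-600', 0.5, 200
    optimizer = optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-4)

    if decay_type == 'step':
        scheduler_function = lrs.StepLR
        kwargs = {'step_size': lr_decay, 'gamma': gamma}
    elif decay_type.find('step') >= 0:
        scheduler_function = lrs.MultiStepLR
        milestones = list(map(int, decay_type.split('-')[1:]))  # [200, 400, 600]
        kwargs = {'milestones': milestones, 'gamma': gamma}

    scheduler = scheduler_function(optimizer, **kwargs)
    print(type(scheduler).__name__, kwargs)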