Commit e2f7393e authored by Sanghyun Son

minor style change

parent ea971d90
Part of merge request !1 (Jan 09, 2018 updates)
@@ -174,8 +174,8 @@ class SRData(data.Dataset):
             hr = imageio.imread(f_hr)
             lr = imageio.imread(f_lr)
         elif self.args.ext.find('sep') >= 0:
-            with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image']
-            with open(f_lr, 'rb') as _f: lr = np.load(_f)[0]['image']
+            with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image']
+            with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image']

         return lr, hr, filename
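Note on the np.load-to-pickle.load switch: judging by the [0]['image'] indexing, each binary file holds a Python-pickled list of dicts rather than a .npy array, so pickle.load is the matching reader. A minimal sketch of the assumed round trip; the file name and array contents are illustrative, only the list-of-dicts layout and the 'image' key come from the diff:

    import pickle
    import numpy as np

    # Hypothetical writer for the files read above: a pickled list of
    # dicts, each holding an image array under the 'image' key.
    patch = {'image': np.zeros((48, 48, 3), dtype=np.uint8)}
    with open('hr_sample.pt', 'wb') as _f:
        pickle.dump([patch], _f)

    # Reader, mirroring the new code path in the hunk above.
    with open('hr_sample.pt', 'rb') as _f:
        hr = pickle.load(_f)[0]['image']
    print(hr.shape)  # (48, 48, 3)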
@@ -64,7 +64,7 @@ class Loss(nn.modules.loss._Loss):
                 self.loss_module, range(args.n_GPUs)
             )

-        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
+        if args.load != '': self.load(ckp.dir, cpu=args.cpu)

     def forward(self, sr, hr):
         losses = []
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable

 class Adversarial(nn.Module):
     def __init__(self, args, gan_type):
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
-from torch.autograd import Variable

 class VGG(nn.Module):
     def __init__(self, conv_index, rgb_range=1):
@@ -4,7 +4,6 @@ from importlib import import_module
 import torch
 import torch.nn as nn
 import torch.utils.model_zoo
-from torch.autograd import Variable

 class Model(nn.Module):
     def __init__(self, args, ckp):
@@ -4,8 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable

 def default_conv(in_channels, out_channels, kernel_size, bias=True):
     return nn.Conv2d(
         in_channels, out_channels, kernel_size,
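The four hunks above all drop a presumably unused `from torch.autograd import Variable` import. On PyTorch 0.4 and later the import is obsolete in any case: Variable was merged into Tensor, so tensors carry autograd state directly. A quick illustration:

    import torch

    # Pre-0.4 style (no longer needed):
    #   x = Variable(torch.ones(3), requires_grad=True)
    # Modern style: the tensor itself tracks gradients.
    x = torch.ones(3, requires_grad=True)
    y = (x * 2).sum()
    y.backward()
    print(x.grad)  # tensor([2., 2., 2.])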
@@ -114,10 +114,8 @@ parser.add_argument('--optimizer', default='ADAM',
                     help='optimizer to use (SGD | ADAM | RMSprop)')
 parser.add_argument('--momentum', type=float, default=0.9,
                     help='SGD momentum')
-parser.add_argument('--beta1', type=float, default=0.9,
-                    help='ADAM beta1')
-parser.add_argument('--beta2', type=float, default=0.999,
-                    help='ADAM beta2')
+parser.add_argument('--beta', type=tuple, default=(0.9, 0.999),
+                    help='ADAM beta')
 parser.add_argument('--epsilon', type=float, default=1e-8,
                     help='ADAM epsilon for numerical stability')
 parser.add_argument('--weight_decay', type=float, default=0,
@@ -134,7 +132,7 @@ parser.add_argument('--skip_threshold', type=float, default='1e8',

 # Log specifications
 parser.add_argument('--save', type=str, default='test',
                     help='file name to save')
-parser.add_argument('--load', type=str, default='.',
+parser.add_argument('--load', type=str, default='',
                     help='file name to load')
 parser.add_argument('--resume', type=int, default=0,
                     help='resume from specific checkpoint')
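Two notes on the option changes. First, the empty-string default for --load replaces the old '.' sentinel, which lets callers test it with plain truthiness (if not args.load:), as the checkpoint hunk below does. Second, type=tuple only behaves as intended for the default value: if --beta were ever passed on the command line, argparse would apply tuple() to the raw string and yield a tuple of single characters. A common workaround is a small parser function; this is a hedged suggestion, not part of the commit:

    import argparse

    def parse_betas(s):
        """Turn a string like '0.9,0.999' into a tuple of floats."""
        return tuple(map(float, s.split(',')))

    parser = argparse.ArgumentParser()
    parser.add_argument('--beta', type=parse_betas, default=(0.9, 0.999),
                        help='ADAM beta')
    print(parser.parse_args(['--beta', '0.9,0.999']).beta)  # (0.9, 0.999)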
@@ -21,7 +21,7 @@ class Trainer():
         self.optimizer = utility.make_optimizer(args, self.model)
         self.scheduler = utility.make_scheduler(args, self.optimizer)

-        if self.args.load != '.':
+        if self.args.load != '':
             self.optimizer.load_state_dict(
                 torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
             )
@@ -48,20 +48,21 @@ class checkpoint():
         self.log = torch.Tensor()

         now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
-        if args.load == '.':
-            if args.save == '.': args.save = now
+        if not args.load:
+            if not args.save:
+                args.save = now
             self.dir = os.path.join('..', 'experiment', args.save)
         else:
             self.dir = os.path.join('..', 'experiment', args.load)
-            if not os.path.exists(self.dir):
-                args.load = '.'
-            else:
+            if os.path.exists(self.dir):
                 self.log = torch.load(self.get_path('psnr_log.pt'))
                 print('Continue from epoch {}...'.format(len(self.log)))
+            else:
+                args.load = ''

         if args.reset:
             os.system('rm -rf ' + self.dir)
-            args.load = '.'
+            args.load = ''

         os.makedirs(self.dir, exist_ok=True)
         os.makedirs(self.get_path('model'), exist_ok=True)
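The reworked checkpoint logic now reads: with no --load, save under args.save (or a timestamp); with --load, resume from the existing experiment directory, or fall back to a fresh run by clearing args.load when that directory is missing. The trailing os.makedirs(..., exist_ok=True) calls are safe in either case. A condensed sketch of that flow, with a hypothetical directory name:

    import os

    load = 'my_experiment'                        # hypothetical --load value
    exp_dir = os.path.join('..', 'experiment', load)
    if not os.path.exists(exp_dir):
        load = ''                                 # fresh run, as in the diff
    os.makedirs(exp_dir, exist_ok=True)           # no error if it already exists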
@@ -171,16 +172,13 @@ def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
     if dataset and dataset.dataset.benchmark:
         shave = scale
         if diff.size(1) > 1:
-            convert = diff.new(1, 3, 1, 1)
-            convert[0, 0, 0, 0] = 65.738
-            convert[0, 1, 0, 0] = 129.057
-            convert[0, 2, 0, 0] = 25.064
-            diff *= (convert / 256)
-            diff = diff.sum(dim=1, keepdim=True)
+            gray_coeffs = [65.738, 129.057, 25.064]
+            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
+            diff = diff.mul(convert).sum(dim=1)
     else:
         shave = scale + 6

-    valid = diff[:, :, shave:-shave, shave:-shave]
+    valid = diff[..., shave:-shave, shave:-shave]
     mse = valid.pow(2).mean()

     return -10 * math.log10(mse)
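The rewritten branch is the usual Y-channel PSNR for super-resolution benchmarks: [65.738, 129.057, 25.064] / 256 are the ITU-R BT.601 RGB-to-luma weights scaled for an 8-bit range, new_tensor(...).view(1, 3, 1, 1) builds them on the same device and dtype as diff, and since sum(dim=1) now drops the channel axis, the Ellipsis slice works for both the 4-D RGB path and the 3-D gray path. A self-contained sketch with names of my own choosing:

    import math
    import torch

    def psnr_y(sr, hr, shave, rgb_range=255):
        # Normalize the error, then reduce RGB to a weighted gray channel.
        diff = (sr - hr) / rgb_range
        convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1) / 256
        diff = diff.mul(convert).sum(dim=1)       # (N, H, W)
        valid = diff[..., shave:-shave, shave:-shave]
        mse = valid.pow(2).mean()
        return -10 * math.log10(mse)              # PSNR with peak value 1

    sr = torch.randint(0, 256, (1, 3, 64, 64)).float()
    hr = torch.randint(0, 256, (1, 3, 64, 64)).float()
    print(psnr_y(sr, hr, shave=4))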
@@ -194,7 +192,7 @@ def make_optimizer(args, my_model):
     elif args.optimizer == 'ADAM':
         optimizer_function = optim.Adam
         kwargs = {
-            'betas': (args.beta1, args.beta2),
+            'betas': args.beta,
             'eps': args.epsilon
         }
     elif args.optimizer == 'RMSprop':
@@ -208,20 +206,12 @@ def make_optimizer(args, my_model):
 def make_scheduler(args, my_optimizer):
     if args.decay_type == 'step':
-        scheduler = lrs.StepLR(
-            my_optimizer,
-            step_size=args.lr_decay,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.StepLR
+        kwargs = {'step_size': args.lr_decay, 'gamma': args.gamma}
     elif args.decay_type.find('step') >= 0:
-        milestones = args.decay_type.split('_')
-        milestones.pop(0)
-        milestones = list(map(lambda x: int(x), milestones))
-        scheduler = lrs.MultiStepLR(
-            my_optimizer,
-            milestones=milestones,
-            gamma=args.gamma
-        )
+        scheduler_function = lrs.MultiStepLR
+        milestones = list(map(lambda x: int(x), args.decay_type.split('-')[1:]))
+        kwargs = {'milestones': milestones, 'gamma': args.gamma}

-    return scheduler
+    return scheduler_function(my_optimizer, **kwargs)
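After this refactor both branches share a single construction call, and the multistep milestones are parsed out of decay_type itself by splitting on '-' and dropping the leading name, so a value shaped like 'multistep-200-400-600' would yield milestones [200, 400, 600]; the exact string format is my inference from the split, not spelled out in the diff. Equivalent usage, written directly against torch.optim:

    import torch.nn as nn
    import torch.optim as optim
    import torch.optim.lr_scheduler as lrs

    model = nn.Linear(4, 4)                       # stand-in model
    optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))

    decay_type = 'multistep-200-400-600'          # assumed format
    milestones = list(map(int, decay_type.split('-')[1:]))
    scheduler = lrs.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)
    print(milestones)  # [200, 400, 600]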