Commit d99e5756 authored by Sanghyun Son

fix vdsr bugs

parent b1df88b5
1 merge request: !1 Jan 09, 2018 updates
@@ -5,16 +5,24 @@ import skimage.color as sc
 
 import torch
 
-def get_patch(*args, patch_size=96, scale=1, multi_scale=False):
+def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
     ih, iw = args[0].shape[:2]
 
-    p = scale if multi_scale else 1
-    tp = p * patch_size
-    ip = tp // scale
+    if not input_large:
+        p = scale if multi else 1
+        tp = p * patch_size
+        ip = tp // scale
+    else:
+        tp = patch_size
+        ip = patch_size
 
     ix = random.randrange(0, iw - ip + 1)
     iy = random.randrange(0, ih - ip + 1)
-    tx, ty = scale * ix, scale * iy
+
+    if not input_large:
+        tx, ty = scale * ix, scale * iy
+    else:
+        tx, ty = ix, iy
 
     ret = [
         args[0][iy:iy + ip, ix:ix + ip, :],
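For orientation, a minimal sketch of what the rewritten get_patch returns; the array sizes below are hypothetical, and the snippet assumes it runs from the repository source directory so that data.common resolves:

import numpy as np

from data import common  # the module patched above

# Hypothetical x2 pair: a 64x64 LR image and its 128x128 HR target.
lr = np.zeros((64, 64, 3), dtype=np.uint8)
hr = np.zeros((128, 128, 3), dtype=np.uint8)

# Default path: the HR crop is patch_size pixels, the LR crop patch_size // scale.
lr_p, hr_p = common.get_patch(lr, hr, patch_size=96, scale=2)
assert lr_p.shape[:2] == (48, 48) and hr_p.shape[:2] == (96, 96)

# input_large path (e.g. a bicubically pre-upsampled input, as VDSR expects):
# both crops share the same coordinates and the same patch_size.
lr_big = np.zeros((128, 128, 3), dtype=np.uint8)
lr_p, hr_p = common.get_patch(lr_big, hr, patch_size=96, scale=2, input_large=True)
assert lr_p.shape == hr_p.shape == (96, 96, 3)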
 import os
 import glob
 import random
+import pickle
 
 from data import common
 
-import pickle
 import numpy as np
 import imageio
 import torch
 import torch.utils.data as data
@@ -29,6 +30,7 @@ class SRData(data.Dataset):
             data_range = data_range[0]
         else:
             data_range = data_range[1]
+        self.begin, self.end = list(map(lambda x: int(x), data_range))
 
         self._set_filesystem(args.dir_data)
         if args.ext.find('img') < 0:
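The new begin/end fields are parsed from the --data_range option (e.g. 801-900, as in the demo.sh benchmark lines below). A quick illustration of the parsing with a hypothetical range string:

data_range = '801-900'.split('-')  # e.g. the DIV2K validation split
begin, end = list(map(lambda x: int(x), data_range))
assert (begin, end) == (801, 900)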
@@ -104,12 +106,10 @@ class SRData(data.Dataset):
         return names_hr, names_lr
 
     def _set_filesystem(self, dir_data):
+        bicubic_type = 'LR_bicubic' if not self.input_large else 'LR_bicubicL'
         self.apath = os.path.join(dir_data, self.name)
         self.dir_hr = os.path.join(self.apath, 'HR')
-        if self.input_large:
-            self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
-        else:
-            self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+        self.dir_lr = os.path.join(self.apath, bicubic_type)
         self.ext = ('.png', '.png')
 
     def _name_hrbin(self):
@@ -145,17 +145,16 @@ class SRData(data.Dataset):
                 'image': imageio.imread(_l)
             } for _l in l]
             with open(f, 'wb') as _f: pickle.dump(b, _f)
+
+        return b
 
     def __getitem__(self, idx):
         lr, hr, filename = self._load_file(idx)
-        lr, hr = self.get_patch(lr, hr)
-        lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors)
-        lr_tensor, hr_tensor = common.np2Tensor(
-            lr, hr, rgb_range=self.args.rgb_range
-        )
+        pair = self.get_patch(lr, hr)
+        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
+        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
 
-        return lr_tensor, hr_tensor, filename
+        return pair_t[0], pair_t[1], filename
 
     def __len__(self):
         if self.train:
@@ -191,17 +190,15 @@ class SRData(data.Dataset):
     def get_patch(self, lr, hr):
         scale = self.scale[self.idx_scale]
-        multi_scale = len(self.scale) > 1
         if self.train:
             lr, hr = common.get_patch(
-                lr,
-                hr,
+                lr, hr,
                 patch_size=self.args.patch_size,
                 scale=scale,
-                multi_scale=multi_scale
+                multi=(len(self.scale) > 1),
+                input_large=self.input_large
             )
-            if not self.args.no_augment:
-                lr, hr = common.augment(lr, hr)
+            if not self.args.no_augment: lr, hr = common.augment(lr, hr)
         else:
             ih, iw = lr.shape[:2]
             hr = hr[0:ih * scale, 0:iw * scale]
@@ -209,5 +206,8 @@ class SRData(data.Dataset):
         return lr, hr
 
     def set_scale(self, idx_scale):
-        self.idx_scale = idx_scale
+        if not self.input_large:
+            self.idx_scale = idx_scale
+        else:
+            self.idx_scale = random.randint(0, len(self.scale) - 1)
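The reworked __getitem__ above threads a single pair list through get_patch, set_channel, and np2Tensor. For reference, a standalone sketch of the HWC-numpy-to-CHW-tensor step; np_to_tensor below is a hypothetical equivalent of data.common.np2Tensor for one image, not the repository's code:

import numpy as np
import torch

def np_to_tensor(img, rgb_range=255):
    # HWC uint8 array -> CHW float tensor scaled to [0, rgb_range].
    chw = np.ascontiguousarray(img.transpose((2, 0, 1)))
    return torch.from_numpy(chw).float() * (rgb_range / 255.0)

pair = [np.zeros((48, 48, 3), np.uint8), np.zeros((96, 96, 3), np.uint8)]
pair_t = [np_to_tensor(p) for p in pair]
assert pair_t[0].shape == (3, 48, 48) and pair_t[1].shape == (3, 96, 96)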
@@ -23,13 +23,9 @@
 #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
 
 # Standard benchmarks (Ex. EDSR_baseline_x4)
-#python main.py --data_test Set5 --scale 4 --pre_train download --test_only --self_ensemble
-#python main.py --data_test Set14 --scale 4 --pre_train download --test_only --self_ensemble
-#python main.py --data_test B100 --scale 4 --pre_train download --test_only --self_ensemble
-#python main.py --data_test Urban100 --scale 4 --pre_train download --test_only --self_ensemble
-#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble
+#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble
 
-python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
+#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
 #python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
 #python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
 #python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
@@ -60,3 +56,4 @@
 #python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
 
 # RCAN_BIX8_G10R20P48, input=48x48, output=384x384
 #python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt
 from model import common
 
 import torch.nn as nn
 import torch.nn.init as init
 
 url = {
     'r20f64': ''
@@ -20,18 +21,19 @@ class VDSR(nn.Module):
         self.sub_mean = common.MeanShift(args.rgb_range)
         self.add_mean = common.MeanShift(args.rgb_range, sign=1)
 
+        def basic_block(in_channels, out_channels, act):
+            return common.BasicBlock(
+                conv, in_channels, out_channels, kernel_size,
+                bias=True, bn=False, act=act
+            )
+
         # define body module
         m_body = []
-        m_body.append(common.BasicBlock(
-            conv, args.n_colors, n_feats, kernel_size, bn=False
-        ))
+        m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
         for _ in range(n_resblocks - 2):
-            m_body.append(common.BasicBlock(
-                conv, n_feats, n_feats, kernel_size, bn=False
-            ))
-        m_body.append(common.BasicBlock(
-            conv, n_feats, args.n_colors, kernel_size, bn=False, act=None
-        ))
+            m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
+        m_body.append(basic_block(n_feats, args.n_colors, None))
 
         self.body = nn.Sequential(*m_body)
 
     def forward(self, x):
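The refactor above only changes how the body is assembled; the topology is still plain VDSR: conv+ReLU blocks with a global residual over a pre-upsampled input. A self-contained sketch of the same structure, with hypothetical defaults standing in for the repository's args and MeanShift layers:

import torch
import torch.nn as nn

class TinyVDSR(nn.Module):
    # Hypothetical standalone sketch; mean shift and weight init omitted.
    def __init__(self, n_colors=3, n_feats=64, n_resblocks=20, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        layers = [nn.Conv2d(n_colors, n_feats, kernel_size, padding=pad), nn.ReLU(True)]
        for _ in range(n_resblocks - 2):
            layers += [nn.Conv2d(n_feats, n_feats, kernel_size, padding=pad), nn.ReLU(True)]
        layers += [nn.Conv2d(n_feats, n_colors, kernel_size, padding=pad)]  # no final activation
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        # VDSR predicts a residual on top of the bicubic-upsampled input.
        return x + self.body(x)

y = TinyVDSR()(torch.zeros(1, 3, 41, 41))  # 41 matches the template's patch_size
assert y.shape == (1, 3, 41, 41)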
@@ -122,11 +122,13 @@ parser.add_argument('--epsilon', type=float, default=1e-8,
                     help='ADAM epsilon for numerical stability')
 parser.add_argument('--weight_decay', type=float, default=0,
                     help='weight decay')
+parser.add_argument('--gclip', type=float, default=0,
+                    help='gradient clipping threshold (0 = no clipping)')
 
 # Loss specifications
 parser.add_argument('--loss', type=str, default='1*L1',
                     help='loss function configuration')
-parser.add_argument('--skip_threshold', type=float, default='1e6',
+parser.add_argument('--skip_threshold', type=float, default='1e8',
                     help='skipping batch that has large error')
 
 # Log specifications
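One quirk worth noting: skip_threshold keeps a string default ('1e8'), but argparse applies the type converter to string defaults, so args.skip_threshold still comes out as the float 1e8. A short demonstration with a hypothetical parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gclip', type=float, default=0)
parser.add_argument('--skip_threshold', type=float, default='1e8')

args = parser.parse_args([])
assert args.gclip == 0 and args.skip_threshold == 1e8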
@@ -44,3 +44,10 @@ def set_template(args):
         args.n_feats = 64
         args.chop = True
+
+    if args.template.find('VDSR') >= 0:
+        args.model = 'VDSR'
+        args.n_resblocks = 20
+        args.n_feats = 64
+        args.patch_size = 41
+        args.lr = 1e-1
@@ -5,6 +5,7 @@ from decimal import Decimal
 
 import utility
 
 import torch
+import torch.nn.utils as utils
 from tqdm import tqdm
 
 class Trainer():
@@ -49,13 +50,13 @@ class Trainer():
             self.optimizer.zero_grad()
             sr = self.model(lr, idx_scale)
             loss = self.loss(sr, hr)
-            if loss.item() < self.args.skip_threshold * self.error_last:
-                loss.backward()
-                self.optimizer.step()
-            else:
-                print('Skip this batch {}! (Loss: {})'.format(
-                    batch + 1, loss.item()
-                ))
+            loss.backward()
+            if self.args.gclip > 0:
+                utils.clip_grad_value_(
+                    self.model.parameters(),
+                    self.args.gclip
+                )
+            self.optimizer.step()
 
             timer_model.hold()
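Gradient clipping is what makes the template's lr = 1e-1 workable for VDSR: torch.nn.utils.clip_grad_value_ clamps each gradient element to [-gclip, gclip] after backward. A minimal sketch of the same train-step pattern outside the Trainer class; the model, data, and clip value are hypothetical:

import torch
import torch.nn as nn
import torch.nn.utils as utils

model = nn.Conv2d(3, 3, 3, padding=1)  # hypothetical stand-in for the SR model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
gclip = 0.4  # hypothetical threshold; 0 would disable clipping

x, target = torch.randn(4, 3, 41, 41), torch.randn(4, 3, 41, 41)

optimizer.zero_grad()
loss = nn.functional.l1_loss(model(x), target)
loss.backward()
if gclip > 0:
    utils.clip_grad_value_(model.parameters(), gclip)  # clamp per-element gradients
optimizer.step()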