diff --git a/src/data/common.py b/src/data/common.py index abd744c058de6c35c3798112a562d3af34d8ff42..26170bcdf5f416c4d480f8b95015f50aee421ddf 100644 --- a/src/data/common.py +++ b/src/data/common.py @@ -5,16 +5,24 @@ import skimage.color as sc import torch -def get_patch(*args, patch_size=96, scale=1, multi_scale=False): +def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): ih, iw = args[0].shape[:2] - p = scale if multi_scale else 1 - tp = p * patch_size - ip = tp // scale + if not input_large: + p = scale if multi else 1 + tp = p * patch_size + ip = tp // scale + else: + tp = patch_size + ip = patch_size ix = random.randrange(0, iw - ip + 1) iy = random.randrange(0, ih - ip + 1) - tx, ty = scale * ix, scale * iy + + if not input_large: + tx, ty = scale * ix, scale * iy + else: + tx, ty = ix, iy ret = [ args[0][iy:iy + ip, ix:ix + ip, :], diff --git a/src/data/srdata.py b/src/data/srdata.py index 2d678ccb74d05c560f47119afa8f22c269e3ba7b..f8489342a4a41204760f77ca6345f62dcb025982 100644 --- a/src/data/srdata.py +++ b/src/data/srdata.py @@ -1,11 +1,12 @@ import os import glob +import random +import pickle from data import common -import pickle + import numpy as np import imageio - import torch import torch.utils.data as data @@ -29,6 +30,7 @@ class SRData(data.Dataset): data_range = data_range[0] else: data_range = data_range[1] + self.begin, self.end = list(map(lambda x: int(x), data_range)) self._set_filesystem(args.dir_data) if args.ext.find('img') < 0: @@ -104,12 +106,10 @@ class SRData(data.Dataset): return names_hr, names_lr def _set_filesystem(self, dir_data): + bicubic_type = 'LR_bicubic' if not self.input_large else 'LR_bicubicL' self.apath = os.path.join(dir_data, self.name) self.dir_hr = os.path.join(self.apath, 'HR') - if self.input_large: - self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') - else: - self.dir_lr = os.path.join(self.apath, 'LR_bicubic') + self.dir_lr = os.path.join(self.apath, bicubic_type) self.ext = ('.png', '.png') def _name_hrbin(self): @@ -145,17 +145,16 @@ class SRData(data.Dataset): 'image': imageio.imread(_l) } for _l in l] with open(f, 'wb') as _f: pickle.dump(b, _f) + return b def __getitem__(self, idx): lr, hr, filename = self._load_file(idx) - lr, hr = self.get_patch(lr, hr) - lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors) - lr_tensor, hr_tensor = common.np2Tensor( - lr, hr, rgb_range=self.args.rgb_range - ) + pair = self.get_patch(lr, hr) + pair = common.set_channel(*pair, n_channels=self.args.n_colors) + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) - return lr_tensor, hr_tensor, filename + return pair_t[0], pair_t[1], filename def __len__(self): if self.train: @@ -191,17 +190,15 @@ class SRData(data.Dataset): def get_patch(self, lr, hr): scale = self.scale[self.idx_scale] - multi_scale = len(self.scale) > 1 if self.train: lr, hr = common.get_patch( - lr, - hr, + lr, hr, patch_size=self.args.patch_size, scale=scale, - multi_scale=multi_scale + multi=(len(self.scale) > 1), + input_large=self.input_large ) - if not self.args.no_augment: - lr, hr = common.augment(lr, hr) + if not self.args.no_augment: lr, hr = common.augment(lr, hr) else: ih, iw = lr.shape[:2] hr = hr[0:ih * scale, 0:iw * scale] @@ -209,5 +206,8 @@ class SRData(data.Dataset): return lr, hr def set_scale(self, idx_scale): - self.idx_scale = idx_scale + if not self.input_large: + self.idx_scale = idx_scale + else: + self.idx_scale = random.randint(0, len(self.scale) - 1) diff --git a/src/demo.sh b/src/demo.sh index 70f12af086037706e9262b051360c6236fa582a8..cc00a165b6b61edd2c9b01b3a670520987872d30 100644 --- a/src/demo.sh +++ b/src/demo.sh @@ -23,13 +23,9 @@ #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models # Standard benchmarks (Ex. EDSR_baseline_x4) -#python main.py --data_test Set5 --scale 4 --pre_train download --test_only --self_ensemble -#python main.py --data_test Set14 --scale 4 --pre_train download --test_only --self_ensemble -#python main.py --data_test B100 --scale 4 --pre_train download --test_only --self_ensemble -#python main.py --data_test Urban100 --scale 4 --pre_train download --test_only --self_ensemble -#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble -python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble +#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble #python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble #python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble #python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble @@ -60,3 +56,4 @@ python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_s #python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt # RCAN_BIX8_G10R20P48, input=48x48, output=384x384 #python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt + diff --git a/src/model/vdsr.py b/src/model/vdsr.py index 46ca41d75bacb51e63afbbc87b6ccdd9472255a6..01d07c5513b7c3a2fd71acd2c0e4250d5f3242cf 100644 --- a/src/model/vdsr.py +++ b/src/model/vdsr.py @@ -1,6 +1,7 @@ from model import common import torch.nn as nn +import torch.nn.init as init url = { 'r20f64': '' @@ -20,18 +21,19 @@ class VDSR(nn.Module): self.sub_mean = common.MeanShift(args.rgb_range) self.add_mean = common.MeanShift(args.rgb_range, sign=1) + def basic_block(in_channels, out_channels, act): + return common.BasicBlock( + conv, in_channels, out_channels, kernel_size, + bias=True, bn=False, act=act + ) + # define body module m_body = [] - m_body.append(common.BasicBlock( - conv, args.n_colors, n_feats, kernel_size, bn=False - )) + m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) for _ in range(n_resblocks - 2): - m_body.append(common.BasicBlock( - conv, n_feats, n_feats, kernel_size, bn=False - )) - m_body.append(common.BasicBlock( - conv, n_feats, args.n_colors, kernel_size, bn=False, act=None - )) + m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) + m_body.append(basic_block(n_feats, args.n_colors, None)) + self.body = nn.Sequential(*m_body) def forward(self, x): diff --git a/src/option.py b/src/option.py index 87e93f2b7fccee17766e1fcb63070b879d798019..c9a0741eff21112fd859868c8a008140d1cfe765 100644 --- a/src/option.py +++ b/src/option.py @@ -122,11 +122,13 @@ parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability') parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') +parser.add_argument('--gclip', type=float, default=0, + help='gradient clipping threshold (0 = no clipping)') # Loss specifications parser.add_argument('--loss', type=str, default='1*L1', help='loss function configuration') -parser.add_argument('--skip_threshold', type=float, default='1e6', +parser.add_argument('--skip_threshold', type=float, default='1e8', help='skipping batch that has large error') # Log specifications diff --git a/src/template.py b/src/template.py index 6a9bea95565d0b5573ebda86554eaec871c54ec1..755a7bce970ffc19c3f6c06f9d28fa84281e5542 100644 --- a/src/template.py +++ b/src/template.py @@ -44,3 +44,10 @@ def set_template(args): args.n_feats = 64 args.chop = True + if args.template.find('VDSR') >= 0: + args.model = 'VDSR' + args.n_resblocks = 20 + args.n_feats = 64 + args.patch_size = 41 + args.lr = 1e-1 + diff --git a/src/trainer.py b/src/trainer.py index baff03c4c12a58b468f56cc4ff69a5cf0b9ca01a..a80d5f933fb24e19c0945a9cd70b5dfb9b52ef81 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -5,6 +5,7 @@ from decimal import Decimal import utility import torch +import torch.nn.utils as utils from tqdm import tqdm class Trainer(): @@ -49,13 +50,13 @@ class Trainer(): self.optimizer.zero_grad() sr = self.model(lr, idx_scale) loss = self.loss(sr, hr) - if loss.item() < self.args.skip_threshold * self.error_last: - loss.backward() - self.optimizer.step() - else: - print('Skip this batch {}! (Loss: {})'.format( - batch + 1, loss.item() - )) + loss.backward() + if self.args.gclip > 0: + utils.clip_grad_value_( + self.model.parameters(), + self.args.gclip + ) + self.optimizer.step() timer_model.hold()