diff --git a/src/data/__init__.py b/src/data/__init__.py index cfe6faa6ef9c8d3d0aa089e15296b78bfc9ff667..26b43f2992d89b3605b2c2663159febeebc35555 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -1,36 +1,49 @@ from importlib import import_module - from dataloader import MSDataLoader -from torch.utils.data.dataloader import default_collate +from torch.utils.data import ConcatDataset + +# This is a simple wrapper function for ConcatDataset +class MyConcatDataset(ConcatDataset): + def __init__(self, datasets): + super(MyConcatDataset, self).__init__(datasets) + + def set_scale(self, idx_scale): + for d in self.datasets: + if hasattr(d, 'set_scale'): d.set_scale(idx_scale) class Data: def __init__(self, args): self.loader_train = None if not args.test_only: - module_train = import_module('data.' + args.data_train.lower()) - trainset = getattr(module_train, args.data_train)(args) + datasets = [] + for d in args.data_train: + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' + m = import_module('data.' + module_name.lower()) + datasets.append(getattr(m, module_name)(args, name=d)) + self.loader_train = MSDataLoader( args, - trainset, + MyConcatDataset(datasets), batch_size=args.batch_size, shuffle=True, pin_memory=not args.cpu ) - if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']: - module_test = import_module('data.benchmark') - testset = getattr(module_test, 'Benchmark')( - args, train=False, name=args.data_test - ) - else: - module_test = import_module('data.' + args.data_test.lower()) - testset = getattr(module_test, args.data_test)(args, train=False) + self.loader_test = [] + for d in args.data_test: + if d in ['Set5', 'Set14', 'B100', 'Urban100']: + m = import_module('data.benchmark') + testset = getattr(m, 'Benchmark')(args, train=False, name=d) + else: + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' + m = import_module('data.' + module_name.lower()) + testset = getattr(m, module_name)(args, train=False, name=d) - self.loader_test = MSDataLoader( - args, - testset, - batch_size=1, - shuffle=False, - pin_memory=not args.cpu - ) + self.loader_test.append(MSDataLoader( + args, + testset, + batch_size=1, + shuffle=False, + pin_memory=not args.cpu + )) diff --git a/src/data/demo.py b/src/data/demo.py index 4849a84ffc073c3e34845cd911eef023eb9cc2af..ff1929c3c7d9520e860bd64a27e26394f8b1a343 100644 --- a/src/data/demo.py +++ b/src/data/demo.py @@ -15,7 +15,6 @@ class Demo(data.Dataset): self.scale = args.scale self.idx_scale = 0 self.train = False - self.do_eval = False self.benchmark = benchmark self.filelist = [] @@ -25,8 +24,7 @@ class Demo(data.Dataset): self.filelist.sort() def __getitem__(self, idx): - filename = os.path.split(self.filelist[idx])[-1] - filename, _ = os.path.splitext(filename) + filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] lr = imageio.imread(self.filelist[idx]) lr, = common.set_channel(lr, n_channels=self.args.n_colors) lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) diff --git a/src/data/div2k.py b/src/data/div2k.py index 7ec07ce931b00b85ab516c3d53e363537f6f56b4..b80592cdf6cfb9bb9ad026122a950239799e0de2 100644 --- a/src/data/div2k.py +++ b/src/data/div2k.py @@ -3,6 +3,16 @@ from data import srdata class DIV2K(srdata.SRData): def __init__(self, args, name='DIV2K', train=True, benchmark=False): + data_range = [r.split('-') for r in args.data_range.split('/')] + if train: + data_range = data_range[0] + else: + if args.test_only and len(data_range) == 1: + data_range = data_range[0] + else: + data_range = data_range[1] + + self.begin, self.end = list(map(lambda x: int(x), data_range)) super(DIV2K, self).__init__( args, name=name, train=train, benchmark=benchmark ) @@ -17,8 +27,6 @@ class DIV2K(srdata.SRData): def _set_filesystem(self, dir_data): super(DIV2K, self)._set_filesystem(dir_data) self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') - if self.input_large: - self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubicL') - else: - self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') + self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') + if self.input_large: self.dir_lr += 'L' diff --git a/src/data/div2kjpeg.py b/src/data/div2kjpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..ed542dddb2c9465dfa7460fbf9f3fa56d4a7785c --- /dev/null +++ b/src/data/div2kjpeg.py @@ -0,0 +1,20 @@ +import os +from data import srdata +from data import div2k + +class DIV2KJPEG(div2k.DIV2K): + def __init__(self, args, name='', train=True, benchmark=False): + self.q_factor = int(name.replace('DIV2K-Q', '')) + super(DIV2KJPEG, self).__init__( + args, name=name, train=train, benchmark=benchmark + ) + + def _set_filesystem(self, dir_data): + self.apath = os.path.join(dir_data, 'DIV2K') + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') + self.dir_lr = os.path.join( + self.apath, 'DIV2K_Q{}'.format(self.q_factor) + ) + if self.input_large: self.dir_lr += 'L' + self.ext = ('.png', '.jpg') + diff --git a/src/data/srdata.py b/src/data/srdata.py index f8489342a4a41204760f77ca6345f62dcb025982..97723cfd37161e3e006a58127b2654b21864b11a 100644 --- a/src/data/srdata.py +++ b/src/data/srdata.py @@ -22,16 +22,6 @@ class SRData(data.Dataset): self.scale = args.scale self.idx_scale = 0 - data_range = [r.split('-') for r in args.data_range.split('/')] - if train: - data_range = data_range[0] - else: - if args.test_only and len(data_range) == 1: - data_range = data_range[0] - else: - data_range = data_range[1] - - self.begin, self.end = list(map(lambda x: int(x), data_range)) self._set_filesystem(args.dir_data) if args.ext.find('img') < 0: path_bin = os.path.join(self.apath, 'bin') @@ -85,8 +75,9 @@ class SRData(data.Dataset): ) if train: - self.repeat \ - = args.test_every // (len(self.images_hr) // args.batch_size) + n_patches = args.batch_size * args.test_every + n_images = len(args.data_train) * len(self.images_hr) + self.repeat = max(n_patches // n_images, 1) # Below functions as used to prepare images def _scan(self): @@ -106,10 +97,10 @@ class SRData(data.Dataset): return names_hr, names_lr def _set_filesystem(self, dir_data): - bicubic_type = 'LR_bicubic' if not self.input_large else 'LR_bicubicL' self.apath = os.path.join(dir_data, self.name) self.dir_hr = os.path.join(self.apath, 'HR') - self.dir_lr = os.path.join(self.apath, bicubic_type) + self.dir_lr = os.path.join(self.apath, 'LR_bicubic') + if self.input_large: self.dir_lr += 'L' self.ext = ('.png', '.png') def _name_hrbin(self): diff --git a/src/demo.sh b/src/demo.sh index cc00a165b6b61edd2c9b01b3a670520987872d30..cab75ea6a99f72c6e4c9250b00199f1616beec95 100644 --- a/src/demo.sh +++ b/src/demo.sh @@ -1,20 +1,21 @@ # EDSR baseline model (x2) -#python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_baseline_x2 --reset +#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset +python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75 # EDSR baseline model (x3) - from EDSR baseline model (x2) -#python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] +#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] # EDSR baseline model (x4) - from EDSR baseline model (x2) -#python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] +#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] # EDSR in the paper (x2) -#python main.py --model EDSR --scale 2 --save EDSR_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset +#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset # EDSR in the paper (x3) - from EDSR (x2) -#python main.py --model EDSR --scale 3 --save EDSR_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] +#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir] # EDSR in the paper (x4) - from EDSR (x2) -#python main.py --model EDSR --scale 4 --save EDSR_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] +#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] # MDSR baseline model #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models @@ -25,11 +26,7 @@ # Standard benchmarks (Ex. EDSR_baseline_x4) #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble -#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble -#python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble -#python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble -#python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble -#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble +#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble # Test your own images #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results @@ -38,7 +35,7 @@ #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results # Advanced - Training with adversarial loss -#python main.py --template GAN --scale 4 --save EDSR_GAN --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download +#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download # RDN BI model (x2) #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset diff --git a/src/option.py b/src/option.py index c9a0741eff21112fd859868c8a008140d1cfe765..afe22277d7d1bf2a3f550c3ff5a061d2d795de6e 100644 --- a/src/option.py +++ b/src/option.py @@ -144,11 +144,15 @@ parser.add_argument('--print_every', type=int, default=100, help='how many batches to wait before logging training status') parser.add_argument('--save_results', action='store_true', help='save output results') +parser.add_argument('--save_gt', action='store_true', + help='save low-resolution and high-resolution images together') args = parser.parse_args() template.set_template(args) args.scale = list(map(lambda x: int(x), args.scale.split('+'))) +args.data_train = args.data_train.split('+') +args.data_test = args.data_test.split('+') if args.epochs == 0: args.epochs = 1e8 diff --git a/src/trainer.py b/src/trainer.py index a80d5f933fb24e19c0945a9cd70b5dfb9b52ef81..77ae2de0c9d8c965ffdd16554a9957dcdc24774c 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -74,67 +74,64 @@ class Trainer(): self.error_last = self.loss.log[-1, -1] def test(self): + torch.set_grad_enabled(False) + epoch = self.scheduler.last_epoch + 1 self.ckp.write_log('\nEvaluation:') - self.ckp.add_log(torch.zeros(1, len(self.scale))) + self.ckp.add_log( + torch.zeros(1, len(self.loader_test), len(self.scale)) + ) self.model.eval() timer_test = utility.timer() if self.args.save_results: self.ckp.begin_background() - with torch.no_grad(): + for idx_data, d in enumerate(self.loader_test): for idx_scale, scale in enumerate(self.scale): - eval_acc = 0 - self.loader_test.dataset.set_scale(idx_scale) - tqdm_test = tqdm(self.loader_test, ncols=80) - for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test): - filename = filename[0] - no_eval = (hr.nelement() == 1) - + d.dataset.set_scale(idx_scale) + for lr, hr, filename, _ in tqdm(d, ncols=80): lr, hr = self.prepare(lr, hr) sr = self.model(lr, idx_scale) sr = utility.quantize(sr, self.args.rgb_range) save_list = [sr] - if not no_eval: - eval_acc += utility.calc_psnr( - sr, hr, scale, self.args.rgb_range, - benchmark=self.loader_test.dataset.benchmark - ) - save_list.extend([lr, hr]) + self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( + sr, hr, scale, self.args.rgb_range, dataset=d + ) + if self.args.save_gt: save_list.extend([lr, hr]) - if self.args.save_results: - self.ckp.save_results(filename, save_list, scale) + self.ckp.save_results(d, filename[0], save_list, scale) - self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test) + self.ckp.log[-1, idx_data, idx_scale] /= len(d) best = self.ckp.log.max(0) self.ckp.write_log( '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( - self.args.data_test, + d.dataset.name, scale, - self.ckp.log[-1, idx_scale], - best[0][idx_scale], - best[1][idx_scale] + 1 + self.ckp.log[-1, idx_data, idx_scale], + best[0][idx_data, idx_scale], + best[1][idx_data, idx_scale] + 1 ) ) - self.ckp.write_log( - 'Forward time: {:.2f}s\n'.format(timer_test.toc()) - ) - + self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) self.ckp.write_log('Saving...') + if self.args.save_results: self.ckp.end_background() if not self.args.test_only: - self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) + self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) + self.ckp.write_log( - 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True ) + torch.set_grad_enabled(True) + def prepare(self, *args): device = torch.device('cpu' if self.args.cpu else 'cuda') def _prepare(tensor): if self.args.precision == 'half': tensor = tensor.half() return tensor.to(device) - + return [_prepare(a) for a in args] def terminate(self): diff --git a/src/utility.py b/src/utility.py index 1793e26434ba4326b4136c699835fadd25876e3e..866c737b376c22a7c76774b9e356c87fce4e86fd 100644 --- a/src/utility.py +++ b/src/utility.py @@ -50,25 +50,23 @@ class checkpoint(): if args.load == '.': if args.save == '.': args.save = now - self.dir = '../experiment/' + args.save + self.dir = os.path.join('..', 'experiment', args.save) else: - self.dir = '../experiment/' + args.load + self.dir = os.path.join('..', 'experiment', args.load) if not os.path.exists(self.dir): args.load = '.' else: - self.log = torch.load(self.dir + '/psnr_log.pt') + self.log = torch.load(self.get_path('psnr_log.pt')) print('Continue from epoch {}...'.format(len(self.log))) if args.reset: os.system('rm -rf ' + self.dir) args.load = '.' - def _make_dir(path): - if not os.path.exists(path): os.makedirs(path) - - _make_dir(self.dir) - _make_dir(self.get_path('model')) - _make_dir(self.get_path('results')) + os.makedirs(self.dir, exist_ok=True) + os.makedirs(self.get_path('model'), exist_ok=True) + for d in args.data_test: + os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w' self.log_file = open(self.get_path('log.txt'), open_type) @@ -110,21 +108,22 @@ class checkpoint(): def plot_psnr(self, epoch): axis = np.linspace(1, epoch, epoch) - label = 'SR on {}'.format(self.args.data_test) - fig = plt.figure() - plt.title(label) - for idx_scale, scale in enumerate(self.args.scale): - plt.plot( - axis, - self.log[:, idx_scale].numpy(), - label='Scale {}'.format(scale) - ) - plt.legend() - plt.xlabel('Epochs') - plt.ylabel('PSNR') - plt.grid(True) - plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test))) - plt.close(fig) + for idx_data, d in enumerate(self.args.data_test): + label = 'SR on {}'.format(d) + fig = plt.figure() + plt.title(label) + for idx_scale, scale in enumerate(self.args.scale): + plt.plot( + axis, + self.log[:, idx_data, idx_scale].numpy(), + label='Scale {}'.format(scale) + ) + plt.legend() + plt.xlabel('Epochs') + plt.ylabel('PSNR') + plt.grid(True) + plt.savefig(self.get_path('test_{}.pdf'.format(d))) + plt.close(fig) def begin_background(self): self.queue = Queue() @@ -148,28 +147,35 @@ class checkpoint(): while not self.queue.empty(): time.sleep(1) for p in self.process: p.join() - def save_results(self, filename, save_list, scale): - filename = self.get_path('results', '{}_x{}_'.format(filename, scale)) - postfix = ('SR', 'LR', 'HR') - for v, p in zip(save_list, postfix): - normalized = v[0].mul(255 / self.args.rgb_range) - tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() - self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) + def save_results(self, dataset, filename, save_list, scale): + if self.args.save_results: + filename = self.get_path( + 'results-{}'.format(dataset.dataset.name), + '{}_x{}_'.format(filename, scale) + ) + + postfix = ('SR', 'LR', 'HR') + for v, p in zip(save_list, postfix): + normalized = v[0].mul(255 / self.args.rgb_range) + tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() + self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) def quantize(img, rgb_range): pixel_range = 255 / rgb_range return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) -def calc_psnr(sr, hr, scale, rgb_range, benchmark=False): - diff = (sr - hr).data.div(rgb_range) - if benchmark: +def calc_psnr(sr, hr, scale, rgb_range, dataset=None): + if hr.nelement() == 1: return 0 + + diff = (sr - hr) / rgb_range + if dataset and dataset.dataset.benchmark: shave = scale if diff.size(1) > 1: convert = diff.new(1, 3, 1, 1) convert[0, 0, 0, 0] = 65.738 convert[0, 1, 0, 0] = 129.057 convert[0, 2, 0, 0] = 25.064 - diff.mul_(convert).div_(256) + diff *= (convert / 256) diff = diff.sum(dim=1, keepdim=True) else: shave = scale + 6 diff --git a/src/videotester.py b/src/videotester.py index 0d20fba7822d96b3115156f57dcc8b5b065a558f..d94bd84dbea3626a2fec7818db238ca063a9f660 100644 --- a/src/videotester.py +++ b/src/videotester.py @@ -20,11 +20,12 @@ class VideoTester(): self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) def test(self): + torch.set_grad_enabled(False) + self.ckp.write_log('\nEvaluation on video:') self.model.eval() timer_test = utility.timer() - torch.set_grad_enabled(False) for idx_scale, scale in enumerate(self.scale): vidcap = cv2.VideoCapture(self.args.dir_demo) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) @@ -57,8 +58,9 @@ class VideoTester(): vidwri.release() self.ckp.write_log( - 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True + 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True ) + torch.set_grad_enabled(True) def prepare(self, *args): device = torch.device('cpu' if self.args.cpu else 'cuda')