Commit a91ccea9 authored by Sanghyun Son

now testing jpeg augmentation

parent d99e5756
Merge request !1: Jan 09, 2018 updates
from importlib import import_module

from dataloader import MSDataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import ConcatDataset

# This is a simple wrapper class for ConcatDataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            datasets = []
            for d in args.data_train:
                # 'DIV2K-Q{quality}' entries map to the DIV2KJPEG dataset class
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

            self.loader_train = MSDataLoader(
                args,
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu
            )
        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(MSDataLoader(
                args,
                testset,
                batch_size=1,
                shuffle=False,
                pin_memory=not args.cpu
            ))
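As a rough usage sketch (assuming args has already been processed by option.py below, so that data_train and data_test are lists such as ['DIV2K', 'DIV2K-Q75']), the wrapper exposes one training loader over the concatenated datasets and a list of test loaders:

# Illustrative sketch only, not part of the commit.
loader = Data(args)
for lr, hr, _, idx_scale in loader.loader_train:   # batches drawn from all training sets
    break
for d in loader.loader_test:                       # one loader per entry in args.data_test
    print(d.dataset.name, len(d))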
@@ -15,7 +15,6 @@ class Demo(data.Dataset):
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.benchmark = benchmark

        self.filelist = []
@@ -25,8 +24,7 @@ class Demo(data.Dataset):
        self.filelist.sort()

    def __getitem__(self, idx):
        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
        lr = imageio.imread(self.filelist[idx])
        lr, = common.set_channel(lr, n_channels=self.args.n_colors)
        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
......
@@ -3,6 +3,16 @@ from data import srdata

class DIV2K(srdata.SRData):
    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        # args.data_range has the form 'train_begin-train_end/test_begin-test_end',
        # e.g. the default '1-800/801-810'
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]

        self.begin, self.end = list(map(lambda x: int(x), data_range))
        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )
@@ -17,8 +27,6 @@ class DIV2K(srdata.SRData):
    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
import os

from data import srdata
from data import div2k

class DIV2KJPEG(div2k.DIV2K):
    def __init__(self, args, name='', train=True, benchmark=False):
        # e.g. name='DIV2K-Q75' gives a JPEG quality factor of 75
        self.q_factor = int(name.replace('DIV2K-Q', ''))
        super(DIV2KJPEG, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'DIV2K')
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(
            self.apath, 'DIV2K_Q{}'.format(self.q_factor)
        )
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.jpg')
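As a quick illustration of the naming convention (the quality value 75 is only an example, not fixed by this class), a dataset name such as 'DIV2K-Q75' keeps the original DIV2K HR targets but reads JPEG-compressed LR inputs from a matching DIV2K_Q75 directory:

# Illustrative only: mirrors the name -> path logic above for an example quality of 75.
import os
name = 'DIV2K-Q75'
q_factor = int(name.replace('DIV2K-Q', ''))                    # 75
dir_lr = os.path.join('DIV2K', 'DIV2K_Q{}'.format(q_factor))   # DIV2K/DIV2K_Q75
dir_hr = os.path.join('DIV2K', 'DIV2K_train_HR')               # HR targets are unchanged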
@@ -22,16 +22,6 @@ class SRData(data.Dataset):
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
@@ -85,8 +75,9 @@ class SRData(data.Dataset):
            )

        if train:
            n_patches = args.batch_size * args.test_every
            n_images = len(args.data_train) * len(self.images_hr)
            self.repeat = max(n_patches // n_images, 1)
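For intuition, the repeat factor keeps the number of patches per "epoch" roughly constant at batch_size * test_every no matter how many datasets are combined; with the repository defaults (batch_size=16, test_every=1000, assumed here) and two 800-image training sets it works out as follows:

# Illustrative only: repository defaults assumed (batch_size=16, test_every=1000).
batch_size, test_every = 16, 1000
n_datasets, images_per_dataset = 2, 800            # e.g. DIV2K + DIV2K-Q75
n_patches = batch_size * test_every                # 16000 patches between evaluations
n_images = n_datasets * images_per_dataset         # 1600 images in total
repeat = max(n_patches // n_images, 1)             # each image is revisited 10 times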
    # Below functions are used to prepare images
    def _scan(self):
@@ -106,10 +97,10 @@ class SRData(data.Dataset):
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.png')

    def _name_hrbin(self):
......
# EDSR baseline model (x2)
#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset
python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75

# EDSR baseline model (x3) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]

# EDSR baseline model (x4) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]

# EDSR in the paper (x2)
#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset

# EDSR in the paper (x3) - from EDSR (x2)
#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir]

# EDSR in the paper (x4) - from EDSR (x2)
#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]

# MDSR baseline model
#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
@@ -25,11 +26,7 @@
# Standard benchmarks (Ex. EDSR_baseline_x4)
#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble

#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
# Test your own images
#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
@@ -38,7 +35,7 @@
#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results
# Advanced - Training with adversarial loss
#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
# RDN BI model (x2)
#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
......
@@ -144,11 +144,15 @@
parser.add_argument('--print_every', type=int, default=100,
                    help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
                    help='save output results')
parser.add_argument('--save_gt', action='store_true',
                    help='save low-resolution and high-resolution images together')

args = parser.parse_args()
template.set_template(args)

args.scale = list(map(lambda x: int(x), args.scale.split('+')))
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')

if args.epochs == 0:
    args.epochs = 1e8
......
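A small example of what the '+'-splitting above produces, using values like those in the demo script (the exact strings are only illustrative):

# Illustrative only: the same parsing as above, applied to literal strings.
scale = list(map(lambda x: int(x), '2+3+4'.split('+')))    # [2, 3, 4]
data_train = 'DIV2K+DIV2K-Q75'.split('+')                  # ['DIV2K', 'DIV2K-Q75']
data_test = 'Set5+Set14+B100+Urban100'.split('+')          # ['Set5', 'Set14', 'B100', 'Urban100']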
@@ -74,61 +74,58 @@ class Trainer():
        self.error_last = self.loss.log[-1, -1]

    def test(self):
        torch.set_grad_enabled(False)

        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        timer_test = utility.timer()
        if self.args.save_results: self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename, _ in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    sr = self.model(lr, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt: save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )

        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        if self.args.save_results: self.ckp.end_background()

        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
......
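The PSNR log kept by checkpoint is now indexed as [epoch, test set, scale] instead of [epoch, scale]; a minimal sketch of that bookkeeping with example sizes (two test sets, one scale, numbers purely illustrative):

# Illustrative only: example sizes, not taken from the commit.
import torch
log = torch.zeros(0, 2, 1)                    # (epochs, test sets, scales)
log = torch.cat([log, torch.zeros(1, 2, 1)])  # add_log: append one row per evaluation
log[-1, 0, 0] += 31.2                         # accumulate PSNR for test set 0, scale 0
best = log.max(0)                             # per-set / per-scale best value and epoch index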
@@ -50,25 +50,23 @@ class checkpoint():
        if args.load == '.':
            if args.save == '.': args.save = now
            self.dir = os.path.join('..', 'experiment', args.save)
        else:
            self.dir = os.path.join('..', 'experiment', args.load)
            if not os.path.exists(self.dir):
                args.load = '.'
            else:
                self.log = torch.load(self.get_path('psnr_log.pt'))
                print('Continue from epoch {}...'.format(len(self.log)))

        if args.reset:
            os.system('rm -rf ' + self.dir)
            args.load = '.'

        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(self.get_path('model'), exist_ok=True)
        for d in args.data_test:
            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)

        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
        self.log_file = open(self.get_path('log.txt'), open_type)
@@ -110,20 +108,21 @@ class checkpoint():
    def plot_psnr(self, epoch):
        axis = np.linspace(1, epoch, epoch)
        for idx_data, d in enumerate(self.args.data_test):
            label = 'SR on {}'.format(d)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_data, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig(self.get_path('test_{}.pdf'.format(d)))
            plt.close(fig)

    def begin_background(self):
@@ -148,8 +147,13 @@ class checkpoint():
        while not self.queue.empty(): time.sleep(1)
        for p in self.process: p.join()

    def save_results(self, dataset, filename, save_list, scale):
        if self.args.save_results:
            filename = self.get_path(
                'results-{}'.format(dataset.dataset.name),
                '{}_x{}_'.format(filename, scale)
            )
            postfix = ('SR', 'LR', 'HR')
            for v, p in zip(save_list, postfix):
                normalized = v[0].mul(255 / self.args.rgb_range)
@@ -160,16 +164,18 @@ def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)

def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    if hr.nelement() == 1: return 0

    diff = (sr - hr) / rgb_range
    if dataset and dataset.dataset.benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff *= (convert / 256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
......
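The three constants above are ITU-R BT.601 luma weights (65.738, 129.057, 25.064; divided by 256 they are roughly 0.257, 0.504, 0.098), so on benchmark sets PSNR is measured on the Y channel only. The elided tail of calc_psnr presumably crops `shave` border pixels and converts the mean squared error to decibels, roughly along these lines (a hedged sketch, not the exact code):

# Sketch only: PSNR from a difference tensor already divided by rgb_range,
# with `shave` border pixels cropped on each side.
import math
def psnr_from_diff(diff, shave):
    valid = diff[..., shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()
    return -10 * math.log10(mse)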
@@ -20,11 +20,12 @@ class VideoTester():
        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))

    def test(self):
        torch.set_grad_enabled(False)

        self.ckp.write_log('\nEvaluation on video:')
        self.model.eval()

        timer_test = utility.timer()
        for idx_scale, scale in enumerate(self.scale):
            vidcap = cv2.VideoCapture(self.args.dir_demo)
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -57,8 +58,9 @@ class VideoTester():
            vidwri.release()

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')
......