diff --git a/experiment/model/EDSR_baseline_x2.pt b/experiment/model/EDSR_baseline_x2.pt
deleted file mode 100644
index b72194e19148695edfcd2c32e79cbb6be61bb07d..0000000000000000000000000000000000000000
Binary files a/experiment/model/EDSR_baseline_x2.pt and /dev/null differ
diff --git a/experiment/model/EDSR_baseline_x3.pt b/experiment/model/EDSR_baseline_x3.pt
deleted file mode 100644
index c0f9cb36462b59b61e5c92595fd0ab294fcb8285..0000000000000000000000000000000000000000
Binary files a/experiment/model/EDSR_baseline_x3.pt and /dev/null differ
diff --git a/experiment/model/EDSR_baseline_x4.pt b/experiment/model/EDSR_baseline_x4.pt
deleted file mode 100644
index c69235dd5286cef22d9d7bd8c409fad7f2a1c6db..0000000000000000000000000000000000000000
Binary files a/experiment/model/EDSR_baseline_x4.pt and /dev/null differ
diff --git a/src/data/sr291.py b/src/data/sr291.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e843178612d64ae75975853085eb01191bc0c21
--- /dev/null
+++ b/src/data/sr291.py
@@ -0,0 +1,6 @@
+from data import srdata
+
+class SR291(srdata.SRData):
+    def __init__(self, args, name='SR291', train=True, benchmark=False):
+        super(SR291, self).__init__(args, name=name)
+
diff --git a/src/data/video.py b/src/data/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..19588a78ae7c8e69bc283eea8cc9b5e06ff89df0
--- /dev/null
+++ b/src/data/video.py
@@ -0,0 +1,44 @@
+import os
+
+from data import common
+
+import cv2
+import numpy as np
+import imageio
+
+import torch
+import torch.utils.data as data
+
+class Video(data.Dataset):
+    def __init__(self, args, name='Video', train=False, benchmark=False):
+        self.args = args
+        self.name = name
+        self.scale = args.scale
+        self.idx_scale = 0
+        self.train = False
+        self.do_eval = False
+        self.benchmark = benchmark
+
+        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
+        self.vidcap = cv2.VideoCapture(args.dir_demo)
+        self.n_frames = 0
+        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    def __getitem__(self, idx):
+        success, lr = self.vidcap.read()
+        if success:
+            self.n_frames += 1
+            lr, = common.set_channel(lr, n_channels=self.args.n_colors)
+            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
+
+            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
+        else:
+            self.vidcap.release()
+            return None
+
+    def __len__(self):
+        return self.total_frames
+
+    def set_scale(self, idx_scale):
+        self.idx_scale = idx_scale
+
diff --git a/src/demo.sh b/src/demo.sh
index dd584b6f00ba8ab8e13844df720cfe19d8bf1c33..70f12af086037706e9262b051360c6236fa582a8 100644
--- a/src/demo.sh
+++ b/src/demo.sh
@@ -2,19 +2,19 @@
 #python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_baseline_x2 --reset
 
 # EDSR baseline model (x3) - from EDSR baseline model (x2)
-#python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_baseline_x3 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt
+#python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
 
 # EDSR baseline model (x4) - from EDSR baseline model (x2)
-#python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt
+#python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
 
 # EDSR in the paper (x2)
 #python main.py --model EDSR --scale 2 --save EDSR_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
 
 # EDSR in the paper (x3) - from EDSR (x2)
-#python main.py --model EDSR --scale 3 --save EDSR_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train ../experiment/EDSR_x2/model/model_best.pt
+#python main.py --model EDSR --scale 3 --save EDSR_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
 
 # EDSR in the paper (x4) - from EDSR (x2)
-#python main.py --model EDSR --scale 4 --save EDSR_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train ../experiment/EDSR_x2/model/model_best.pt
+#python main.py --model EDSR --scale 4 --save EDSR_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
 
 # MDSR baseline model
 #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
@@ -23,26 +23,26 @@
 #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
 
 # Standard benchmarks (Ex. EDSR_baseline_x4)
-#python main.py --data_test Set5 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
-#python main.py --data_test Set14 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
-#python main.py --data_test B100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
-#python main.py --data_test Urban100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
-#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
-
-#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
-#python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
-#python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
-#python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
-#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
+#python main.py --data_test Set5 --scale 4 --pre_train download --test_only --self_ensemble
+#python main.py --data_test Set14 --scale 4 --pre_train download --test_only --self_ensemble
+#python main.py --data_test B100 --scale 4 --pre_train download --test_only --self_ensemble
+#python main.py --data_test Urban100 --scale 4 --pre_train download --test_only --self_ensemble
+#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble
+
+python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
+#python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
+#python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
+#python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
+#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
 
 # Test your own images
-#python main.py --data_test Demo --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --save_results
+#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
 
 # Advanced - Test with JPEG images
-#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train ../experiment/model/MDSR_baseline_jpeg.pt --test_only --save_results
+#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results
 
 # Advanced - Training with adversarial loss
-#python main.py --template GAN --scale 4 --save EDSR_GAN --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train ../experiment/model/EDSR_baseline_x4.pt
+#python main.py --template GAN --scale 4 --save EDSR_GAN --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
 
 # RDN BI model (x2)
 #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
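
Note: with the bundled ../experiment/model/*.pt checkpoints deleted above, demo.sh now points --pre_train at "download" (or a placeholder directory) instead of a hard-coded path. The new Video dataset streams frames one at a time from cv2.VideoCapture rather than loading the whole clip; a minimal consumption sketch, assuming a hypothetical args namespace carrying only the fields video.py actually reads ('demo.mp4' is a placeholder clip, run from src/):

    from types import SimpleNamespace
    from data.video import Video

    # Stand-in for the parsed option.py arguments.
    args = SimpleNamespace(scale=[4], n_colors=3, rgb_range=255,
                           dir_demo='demo.mp4')  # placeholder clip
    dataset = Video(args)
    for i in range(len(dataset)):
        lr, _, name = dataset[i]  # frame tensor, dummy HR entry (-1), frame id
        print(name, tuple(lr.shape))
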
diff --git a/src/loss/__init__.py b/src/loss/__init__.py
index dc86746d91a531295f6b94947c620de2099d9fca..27c2e6b828db5e97bad29790bb3416204b5dc9fa 100644
--- a/src/loss/__init__.py
+++ b/src/loss/__init__.py
@@ -113,7 +113,7 @@ class Loss(nn.modules.loss._Loss):
         plt.xlabel('Epochs')
         plt.ylabel('Loss')
         plt.grid(True)
-        plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
+        plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
         plt.close(fig)
 
     def get_loss_module(self):
diff --git a/src/main.py b/src/main.py
index 1a2a613762b61b5b81395ea38b0c0fb0c4188650..5ff359de1111c25021df1acf2205add56eac56b4 100644
--- a/src/main.py
+++ b/src/main.py
@@ -6,18 +6,24 @@
 import model
 import loss
 from option import args
 from trainer import Trainer
+from videotester import VideoTester
 
 torch.manual_seed(args.seed)
 checkpoint = utility.checkpoint(args)
 
-if checkpoint.ok:
-    loader = data.Data(args)
+if args.data_test == 'video':
     model = model.Model(args, checkpoint)
-    loss = loss.Loss(args, checkpoint) if not args.test_only else None
-    t = Trainer(args, loader, model, loss, checkpoint)
-    while not t.terminate():
-        t.train()
-        t.test()
+    t = VideoTester(args, model, checkpoint)
+    t.test()
+else:
+    if checkpoint.ok:
+        loader = data.Data(args)
+        model = model.Model(args, checkpoint)
+        loss = loss.Loss(args, checkpoint) if not args.test_only else None
+        t = Trainer(args, loader, model, loss, checkpoint)
+        while not t.terminate():
+            t.train()
+            t.test()
 
-    checkpoint.done()
+        checkpoint.done()
diff --git a/src/model/__init__.py b/src/model/__init__.py
index 4fa87f350eec998fd318de2d73a7aa109b668bb9..cdb6fd8bdface47ad2107453ccdb9f2d0c5b4477 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -3,6 +3,7 @@ from importlib import import_module
 
 import torch
 import torch.nn as nn
+import torch.utils.model_zoo
 from torch.autograd import Variable
 
 class Model(nn.Module):
@@ -29,7 +30,7 @@
             self.model = nn.DataParallel(self.model, range(args.n_GPUs))
 
         self.load(
-            ckp.dir,
+            ckp.get_path('model'),
             pre_train=args.pre_train,
             resume=args.resume,
             cpu=args.cpu
@@ -39,16 +40,14 @@
     def forward(self, x, idx_scale):
         self.idx_scale = idx_scale
         target = self.get_model()
-        if hasattr(target, 'set_scale'):
-            target.set_scale(idx_scale)
-
+        if hasattr(target, 'set_scale'): target.set_scale(idx_scale)
         if self.self_ensemble and not self.training:
             if self.chop:
                 forward_function = self.forward_chop
             else:
                 forward_function = self.model.forward
 
-            return self.forward_x8(x, forward_function)
+            return self.forward_x8(x, forward_function=forward_function)
         elif self.chop and not self.training:
             return self.forward_chop(x)
         else:
@@ -66,22 +65,17 @@
 
     def save(self, apath, epoch, is_best=False):
         target = self.get_model()
-        torch.save(
-            target.state_dict(),
-            os.path.join(apath, 'model', 'model_latest.pt')
-        )
+        save_dirs = [os.path.join(apath, 'model_latest.pt')]
+
         if is_best:
-            torch.save(
-                target.state_dict(),
-                os.path.join(apath, 'model', 'model_best.pt')
-            )
-
+            save_dirs.append(os.path.join(apath, 'model_best.pt'))
         if self.save_models:
-            torch.save(
-                target.state_dict(),
-                os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
+            save_dirs.append(
+                os.path.join(apath, 'model_{}.pt'.format(epoch))
             )
 
+        for s in save_dirs: torch.save(target.state_dict(), s)
+
     def load(self, apath, pre_train='.', resume=-1, cpu=False):
         if cpu:
             kwargs = {'map_location': lambda storage, loc: storage}
@@ -91,31 +85,29 @@
         load_from = None
         if resume == -1:
             load_from = torch.load(
-                os.path.join(apath, 'model', 'model_latest.pt'),
+                os.path.join(apath, 'model_latest.pt'),
                 **kwargs
             )
         elif resume == 0:
-            if pre_train != '.':
-                if pre_train == 'download':
-                    print('Download the model')
-                    dir_model = os.path.join('..', 'models')
-                    os.makedirs(dir_model, exist_ok=True)
-                    load_from = torch.utils.model_zoo.load_url(
-                        self.get_model().url,
-                        model_dir=dir_model,
-                        **kwargs
-                    )
-                else:
-                    print('Load the model from {}'.format(pre_train))
-                    load_from = torch.load(pre_train, **kwargs)
+            if pre_train == 'download':
+                print('Download the model')
+                dir_model = os.path.join('..', 'models')
+                os.makedirs(dir_model, exist_ok=True)
+                load_from = torch.utils.model_zoo.load_url(
+                    self.get_model().url,
+                    model_dir=dir_model,
+                    **kwargs
+                )
+            elif pre_train != '':
+                print('Load the model from {}'.format(pre_train))
+                load_from = torch.load(pre_train, **kwargs)
         else:
             load_from = torch.load(
-                os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
+                os.path.join(apath, 'model_{}.pt'.format(resume)),
                 **kwargs
             )
 
-        if load_from:
-            self.get_model().load_state_dict(load_from, strict=False)
+        if load_from: self.get_model().load_state_dict(load_from, strict=False)
 
     def forward_chop(self, *args, shave=10, min_size=160000):
         if self.input_large:
@@ -152,8 +144,7 @@
             if not list_y:
                 list_y = [[_y] for _y in y]
             else:
-                for _list_y, _y in zip(list_y, y):
-                    _list_y.append(_y)
+                for _list_y, _y in zip(list_y, y): _list_y.append(_y)
 
         h, w = scale * h, scale * w
         h_half, w_half = scale * h_half, scale * w_half
@@ -196,8 +187,7 @@
 
         list_x = []
         for a in args:
             x = [a]
-            for tf in 'v', 'h', 't':
-                x.extend([_transform(_x, tf) for _x in x])
+            for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
 
             list_x.append(x)
@@ -208,8 +198,7 @@
             if not list_y:
                 list_y = [[_y] for _y in y]
             else:
-                for _list_y, _y in zip(list_y, y):
-                    _list_y.append(_y)
+                for _list_y, _y in zip(list_y, y): _list_y.append(_y)
 
         for _list_y in list_y:
             for i in range(len(_list_y)):
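
Note: forward_x8 keeps its geometric self-ensemble: each input is expanded into its eight flip/transpose variants, super-resolved, mapped back, and averaged. A standalone sketch of that scheme, assuming NCHW tensors (this is an illustration, not the repo's exact helper):

    import torch

    def x8_ensemble(x, forward):
        # Each pass doubles the list, so the bit pattern of index i
        # records which of v/h/t was applied (v=1, h=2, t=4).
        ops = {'v': lambda t: t.flip(-1),
               'h': lambda t: t.flip(-2),
               't': lambda t: t.transpose(-2, -1)}
        xs = [x]
        for tf in 'v', 'h', 't':
            xs.extend(ops[tf](_x) for _x in list(xs))

        ys = [forward(_x) for _x in xs]
        # Undo the transforms in reverse order: t, then h, then v.
        for i in range(len(ys)):
            if i > 3:
                ys[i] = ops['t'](ys[i])
            if i % 4 > 1:
                ys[i] = ops['h'](ys[i])
            if (i % 4) % 2 == 1:
                ys[i] = ops['v'](ys[i])
        return torch.stack(ys).mean(0)
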
diff --git a/src/model/common.py b/src/model/common.py
index c20a79577c2915a4bf19cdab303a3475434904e4..79d0a0ec35b836cd7a0c81428af0e111998ce83f 100644
--- a/src/model/common.py
+++ b/src/model/common.py
@@ -12,24 +12,22 @@ def default_conv(in_channels, out_channels, kernel_size, bias=True):
         padding=(kernel_size//2), bias=bias)
 
 class MeanShift(nn.Conv2d):
-    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
+    def __init__(
+        self, rgb_range,
+        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
+
         super(MeanShift, self).__init__(3, 3, kernel_size=1)
         std = torch.Tensor(rgb_std)
-        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
-        self.weight.data.div_(std.view(3, 1, 1, 1))
-        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
-        self.bias.data.div_(std)
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
+        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
         self.requires_grad = False
 
 class BasicBlock(nn.Sequential):
     def __init__(
-        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
+        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
         bn=True, act=nn.ReLU(True)):
 
-        m = [nn.Conv2d(
-            in_channels, out_channels, kernel_size,
-            padding=(kernel_size//2), stride=stride, bias=bias)
-        ]
+        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
         if bn: m.append(nn.BatchNorm2d(out_channels))
         if act is not None: m.append(act)
         super(BasicBlock, self).__init__(*m)
diff --git a/src/model/edsr.py b/src/model/edsr.py
index ead8023b0f3cd1e02f70b3e90f2d5f0e1355881b..4ed986e2f310a965ce9519a1c7fafff5d81bf1ba 100644
--- a/src/model/edsr.py
+++ b/src/model/edsr.py
@@ -18,17 +18,14 @@
     def __init__(self, args, conv=common.default_conv):
         super(EDSR, self).__init__()
 
-        n_resblock = args.n_resblocks
+        n_resblocks = args.n_resblocks
         n_feats = args.n_feats
         kernel_size = 3
         scale = args.scale[0]
         act = nn.ReLU(True)
         self.url = url['r{}f{}x{}'.format(n_resblocks, n_feats, scale)]
-
-        rgb_mean = (0.4488, 0.4371, 0.4040)
-        rgb_std = (1.0, 1.0, 1.0)
-        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
-        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
+        self.sub_mean = common.MeanShift(args.rgb_range)
+        self.add_mean = common.MeanShift(args.rgb_range, sign=1)
 
         # define head module
         m_head = [conv(args.n_colors, n_feats, kernel_size)]
@@ -37,7 +34,7 @@
 
         m_body = [
             common.ResBlock(
                 conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
-            ) for _ in range(n_resblock)
+            ) for _ in range(n_resblocks)
         ]
         m_body.append(conv(n_feats, n_feats, kernel_size))
diff --git a/src/model/mdsr.py b/src/model/mdsr.py
index 354059d6b92b79599a4baebac0cbaa36f618a2f4..4a5c86ec294719e3025afc631f7d4381b22d2ea0 100644
--- a/src/model/mdsr.py
+++ b/src/model/mdsr.py
@@ -16,14 +16,11 @@
 
         n_resblocks = args.n_resblocks
         n_feats = args.n_feats
         kernel_size = 3
+        act = nn.ReLU(True)
         self.scale_idx = 0
         self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
-
-        act = nn.ReLU(True)
-
-        rgb_mean = (0.4488, 0.4371, 0.4040)
-        rgb_std = (1.0, 1.0, 1.0)
-        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
+        self.sub_mean = common.MeanShift(args.rgb_range)
+        self.add_mean = common.MeanShift(args.rgb_range, sign=1)
 
         m_head = [conv(args.n_colors, n_feats, kernel_size)]
@@ -47,8 +44,6 @@
 
         m_tail = [conv(n_feats, args.n_colors, kernel_size)]
 
-        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
-
         self.head = nn.Sequential(*m_head)
         self.body = nn.Sequential(*m_body)
         self.tail = nn.Sequential(*m_tail)
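
Note: MeanShift now defaults to the DIV2K statistics, so EDSR/MDSR (and RCAN/VDSR below) build their subtract/add pair without repeating the constants. Both layers are 1x1 convolutions that cancel exactly; a quick round-trip check, assuming it is run from src/ so that model.common imports:

    import torch
    from model import common

    sub_mean = common.MeanShift(255)          # sign=-1: subtract the DIV2K mean
    add_mean = common.MeanShift(255, sign=1)  # sign=+1: add it back

    x = torch.rand(1, 3, 8, 8) * 255
    with torch.no_grad():
        y = add_mean(sub_mean(x))
    print(torch.allclose(x, y, atol=1e-4))    # True: the two shifts cancel
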
diff --git a/src/model/rcan.py b/src/model/rcan.py
index 1971c58b68f83ca3c22db3ed0bbe0f61468ec31f..a53e6f8e5822482d1b891eec5029bd3f0a70a862 100644
--- a/src/model/rcan.py
+++ b/src/model/rcan.py
@@ -79,9 +79,7 @@
         act = nn.ReLU(True)
 
         # RGB mean for DIV2K
-        rgb_mean = (0.4488, 0.4371, 0.4040)
-        rgb_std = (1.0, 1.0, 1.0)
-        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
+        self.sub_mean = common.MeanShift(args.rgb_range)
 
         # define head module
         modules_head = [conv(args.n_colors, n_feats, kernel_size)]
@@ -99,7 +97,7 @@
             common.Upsampler(conv, scale, n_feats, act=False),
             conv(n_feats, args.n_colors, kernel_size)]
 
-        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
+        self.add_mean = common.MeanShift(args.rgb_range, sign=1)
 
         self.head = nn.Sequential(*modules_head)
         self.body = nn.Sequential(*modules_body)
diff --git a/src/model/vdsr.py b/src/model/vdsr.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ca41d75bacb51e63afbbc87b6ccdd9472255a6
--- /dev/null
+++ b/src/model/vdsr.py
@@ -0,0 +1,44 @@
+from model import common
+
+import torch.nn as nn
+
+url = {
+    'r20f64': ''
+}
+
+def make_model(args, parent=False):
+    return VDSR(args)
+
+class VDSR(nn.Module):
+    def __init__(self, args, conv=common.default_conv):
+        super(VDSR, self).__init__()
+
+        n_resblocks = args.n_resblocks
+        n_feats = args.n_feats
+        kernel_size = 3
+        self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
+        self.sub_mean = common.MeanShift(args.rgb_range)
+        self.add_mean = common.MeanShift(args.rgb_range, sign=1)
+
+        # define body module
+        m_body = []
+        m_body.append(common.BasicBlock(
+            conv, args.n_colors, n_feats, kernel_size, bn=False
+        ))
+        for _ in range(n_resblocks - 2):
+            m_body.append(common.BasicBlock(
+                conv, n_feats, n_feats, kernel_size, bn=False
+            ))
+        m_body.append(common.BasicBlock(
+            conv, n_feats, args.n_colors, kernel_size, bn=False, act=None
+        ))
+        self.body = nn.Sequential(*m_body)
+
+    def forward(self, x):
+        x = self.sub_mean(x)
+        res = self.body(x)
+        res += x
+        x = self.add_mean(res)
+
+        return x
+
diff --git a/src/option.py b/src/option.py
index 93fc4053cfd6a08bdae06dfda14f7ed54a5df544..87e93f2b7fccee17766e1fcb63070b879d798019 100644
--- a/src/option.py
+++ b/src/option.py
@@ -50,7 +50,7 @@ parser.add_argument('--model', default='EDSR',
 
 parser.add_argument('--act', type=str, default='relu',
                     help='activation function')
-parser.add_argument('--pre_train', type=str, default='.',
+parser.add_argument('--pre_train', type=str, default='',
                     help='pre-trained model directory')
 parser.add_argument('--extend', type=str, default='.',
                     help='pre-trained model directory')
diff --git a/src/trainer.py b/src/trainer.py
index c35b2c7bf8daf71b0de21c984edab52e8850cc13..c610294c40fadd9eaefb139c0a33596d450ec4e3 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -5,7 +5,6 @@ from decimal import Decimal
 import utility
 
 import torch
-from torch.autograd import Variable
 from tqdm import tqdm
 
 class Trainer():
diff --git a/src/utility.py b/src/utility.py
index b8922b7e4435abd1fdfef625dba035117e48470e..b5033c65ecafb4a9ffdebc2bacd72d7869be3a00 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -64,27 +64,30 @@
             if not os.path.exists(path): os.makedirs(path)
 
         _make_dir(self.dir)
-        _make_dir(self.dir + '/model')
-        _make_dir(self.dir + '/results')
+        _make_dir(self.get_path('model'))
+        _make_dir(self.get_path('results'))
 
-        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
-        self.log_file = open(self.dir + '/log.txt', open_type)
-        with open(self.dir + '/config.txt', open_type) as f:
+        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
+        self.log_file = open(self.get_path('log.txt'), open_type)
+        with open(self.get_path('config.txt'), open_type) as f:
             f.write(now + '\n\n')
             for arg in vars(args):
                 f.write('{}: {}\n'.format(arg, getattr(args, arg)))
             f.write('\n')
 
+    def get_path(self, *subdir):
+        return os.path.join(self.dir, *subdir)
+
     def save(self, trainer, epoch, is_best=False):
-        trainer.model.save(self.dir, epoch, is_best=is_best)
+        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
         trainer.loss.save(self.dir)
         trainer.loss.plot_loss(self.dir, epoch)
 
         self.plot_psnr(epoch)
-        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
+        torch.save(self.log, self.get_path('psnr_log.pt'))
         torch.save(
             trainer.optimizer.state_dict(),
-            os.path.join(self.dir, 'optimizer.pt')
+            self.get_path('optimizer.pt')
         )
 
     def add_log(self, log):
@@ -95,7 +98,7 @@
         self.log_file.write(log + '\n')
         if refresh:
             self.log_file.close()
-            self.log_file = open(self.dir + '/log.txt', 'a')
+            self.log_file = open(self.get_path('log.txt'), 'a')
 
     def done(self):
         self.log_file.close()
@@ -115,14 +118,14 @@
 
         plt.xlabel('Epochs')
         plt.ylabel('PSNR')
         plt.grid(True)
-        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
+        plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test)))
         plt.close(fig)
 
     def save_results(self, filename, save_list, scale):
-        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
+        filename = self.get_path('results', '{}_x{}_'.format(filename, scale))
         postfix = ('SR', 'LR', 'HR')
         for v, p in zip(save_list, postfix):
-            normalized = v[0].data.mul(255 / self.args.rgb_range)
+            normalized = v[0].mul(255 / self.args.rgb_range)
             ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
             misc.imsave('{}{}.png'.format(filename, p), ndarr)
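
Note: every file under the experiment directory is now composed through checkpoint.get_path, which is why Model.save/Model.load above join bare filenames such as model_latest.pt onto the ckp.get_path('model') they receive. A minimal stand-in showing just the path logic (hypothetical class and run directory, for illustration only):

    import os

    class CheckpointPaths:
        # Stand-in for utility.checkpoint; only the layout helper.
        def __init__(self, root):
            self.dir = root

        def get_path(self, *subdir):
            return os.path.join(self.dir, *subdir)

    ckp = CheckpointPaths('../experiment/EDSR_x4')  # hypothetical run dir
    print(ckp.get_path('model', 'model_best.pt'))
    # ../experiment/EDSR_x4/model/model_best.pt
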
diff --git a/src/videotester.py b/src/videotester.py
new file mode 100644
index 0000000000000000000000000000000000000000..2732a0d411c46dcbfc414a0aa0c1baeec420dad3
--- /dev/null
+++ b/src/videotester.py
@@ -0,0 +1,78 @@
+import os
+import math
+
+import utility
+from data import common
+
+import torch
+import cv2
+
+from tqdm import tqdm
+
+class VideoTester():
+    def __init__(self, args, my_model, ckp):
+        self.args = args
+        self.scale = args.scale
+
+        self.ckp = ckp
+        self.model = my_model
+
+        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
+
+    def test(self):
+        self.ckp.write_log('\nEvaluation on video:')
+        self.model.eval()
+
+        timer_test = utility.timer()
+        torch.set_grad_enabled(False)
+        for idx_scale, scale in enumerate(self.scale):
+            vidcap = cv2.VideoCapture(self.args.dir_demo)
+            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+            total_batches = math.ceil(total_frames / self.args.batch_size)
+            vidwri = cv2.VideoWriter(
+                self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
+                cv2.VideoWriter_fourcc(*'XVID'),
+                int(vidcap.get(cv2.CAP_PROP_FPS)),
+                (
+                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                )
+            )
+
+            tqdm_test = tqdm(range(total_batches), ncols=80)
+            for _ in tqdm_test:
+                fs = []
+                for _ in range(self.args.batch_size):
+                    success, lr = vidcap.read()
+                    if success:
+                        fs.append(lr)
+                    else:
+                        break
+
+                fs = common.set_channel(*fs, n_channels=self.args.n_colors)
+                fs = common.np2Tensor(*fs, rgb_range=self.args.rgb_range)
+                lr = torch.stack(fs, dim=0)
+                lr, = self.prepare(lr)
+                sr = self.model(lr, idx_scale)
+                sr = utility.quantize(sr, self.args.rgb_range)
+
+                for i in range(sr.size(0)):
+                    normalized = sr[i].mul(255 / self.args.rgb_range)
+                    ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
+                    vidwri.write(ndarr)
+
+            vidcap.release()
+            vidwri.release()
+
+        self.ckp.write_log(
+            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
+        )
+
+    def prepare(self, *args):
+        device = torch.device('cpu' if self.args.cpu else 'cuda')
+        def _prepare(tensor):
+            if self.args.precision == 'half': tensor = tensor.half()
+            return tensor.to(device)
+
+        return [_prepare(a) for a in args]
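
Note: together with the main.py dispatch, the video path runs end to end via something like `python main.py --data_test video --dir_demo <clip> --scale 4 --pre_train download`. Spelled out programmatically, a sketch of what the new branch does (assumes src/ as the working directory and an args namespace parsed with those flags):

    import torch

    import utility
    import model
    from option import args          # expects --data_test video --dir_demo <clip>
    from videotester import VideoTester

    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    net = model.Model(args, checkpoint)        # resolves --pre_train (e.g. download)
    VideoTester(args, net, checkpoint).test()  # writes <name>_x<scale>.avi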