Skip to content
Snippets Groups Projects
Commit ab6b0ad7 authored by Sanghyun Son's avatar Sanghyun Son
Browse files

need test

parent 98c3cf65
Branches
No related tags found
1 merge request!1Jan 09, 2018 updates
File deleted
File deleted
File deleted
from data import srdata
class SR291(srdata.SRData):
    """291-image super-resolution training set, thin wrapper over SRData."""

    def __init__(self, args, name='SR291', train=True, benchmark=False):
        # Bug fix: the original dropped ``train`` and ``benchmark`` when
        # delegating, so callers requesting an eval/benchmark split silently
        # got the base class defaults instead.
        # NOTE(review): assumes srdata.SRData.__init__ accepts these
        # keywords (it exposes the same parameter names) — confirm.
        super(SR291, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )
import os
from data import common
import cv2
import numpy as np
import imageio
import torch
import torch.utils.data as data
class Video(data.Dataset):
    """Frame-by-frame dataset over a video file, for inference only.

    Frames are decoded sequentially from ``args.dir_demo`` with OpenCV.
    Each item is ``(lr_tensor, -1, name)`` where ``-1`` stands in for the
    missing HR ground truth and ``name`` numbers the frame within the clip.
    """

    def __init__(self, args, name='Video', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False       # always evaluation mode, regardless of arg
        self.do_eval = False
        self.benchmark = benchmark

        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
        self.vidcap = cv2.VideoCapture(args.dir_demo)
        self.n_frames = 0
        # NOTE(review): CAP_PROP_FRAME_COUNT is a container hint and can be
        # inaccurate for some codecs; __getitem__ tolerates a short stream
        # by returning None once reads start failing.
        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __getitem__(self, idx):
        success, lr = self.vidcap.read()
        if success:
            self.n_frames += 1
            lr, = common.set_channel(lr, n_channels=self.args.n_colors)
            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
        else:
            # Bug fix: the original called ``vidcap.release()`` on a bare,
            # undefined name, raising NameError instead of releasing the
            # capture handle held on ``self``.
            self.vidcap.release()
            return None

    def __len__(self):
        return self.total_frames

    def set_scale(self, idx_scale):
        # Selects which entry of args.scale subsequent reads correspond to.
        self.idx_scale = idx_scale
...@@ -2,19 +2,19 @@ ...@@ -2,19 +2,19 @@
#python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_baseline_x2 --reset #python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_baseline_x2 --reset
# EDSR baseline model (x3) - from EDSR baseline model (x2) # EDSR baseline model (x3) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_baseline_x3 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt #python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
# EDSR baseline model (x4) - from EDSR baseline model (x2) # EDSR baseline model (x4) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt #python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
# EDSR in the paper (x2) # EDSR in the paper (x2)
#python main.py --model EDSR --scale 2 --save EDSR_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset #python main.py --model EDSR --scale 2 --save EDSR_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
# EDSR in the paper (x3) - from EDSR (x2) # EDSR in the paper (x3) - from EDSR (x2)
#python main.py --model EDSR --scale 3 --save EDSR_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train ../experiment/EDSR_x2/model/model_best.pt #python main.py --model EDSR --scale 3 --save EDSR_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
# EDSR in the paper (x4) - from EDSR (x2) # EDSR in the paper (x4) - from EDSR (x2)
#python main.py --model EDSR --scale 4 --save EDSR_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train ../experiment/EDSR_x2/model/model_best.pt #python main.py --model EDSR --scale 4 --save EDSR_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
# MDSR baseline model # MDSR baseline model
#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
...@@ -23,26 +23,26 @@ ...@@ -23,26 +23,26 @@
#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
# Standard benchmarks (Ex. EDSR_baseline_x4) # Standard benchmarks (Ex. EDSR_baseline_x4)
#python main.py --data_test Set5 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test Set5 --scale 4 --pre_train download --test_only --self_ensemble
#python main.py --data_test Set14 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test Set14 --scale 4 --pre_train download --test_only --self_ensemble
#python main.py --data_test B100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test B100 --scale 4 --pre_train download --test_only --self_ensemble
#python main.py --data_test Urban100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test Urban100 --scale 4 --pre_train download --test_only --self_ensemble
#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble
#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
#python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble #python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
#python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble #python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
#python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble #python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble #python main.py --data_test DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
# Test your own images # Test your own images
#python main.py --data_test Demo --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --save_results #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
# Advanced - Test with JPEG images # Advanced - Test with JPEG images
#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train ../experiment/model/MDSR_baseline_jpeg.pt --test_only --save_results #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results
# Advanced - Training with adversarial loss # Advanced - Training with adversarial loss
#python main.py --template GAN --scale 4 --save EDSR_GAN --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train ../experiment/model/EDSR_baseline_x4.pt #python main.py --template GAN --scale 4 --save EDSR_GAN --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
# RDN BI model (x2) # RDN BI model (x2)
#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
......
...@@ -113,7 +113,7 @@ class Loss(nn.modules.loss._Loss): ...@@ -113,7 +113,7 @@ class Loss(nn.modules.loss._Loss):
plt.xlabel('Epochs') plt.xlabel('Epochs')
plt.ylabel('Loss') plt.ylabel('Loss')
plt.grid(True) plt.grid(True)
plt.savefig('{}/loss_{}.pdf'.format(apath, l['type'])) plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
plt.close(fig) plt.close(fig)
def get_loss_module(self): def get_loss_module(self):
......
...@@ -6,10 +6,16 @@ import model ...@@ -6,10 +6,16 @@ import model
import loss import loss
from option import args from option import args
from trainer import Trainer from trainer import Trainer
from videotester import VideoTester
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args) checkpoint = utility.checkpoint(args)
if args.data_test == 'video':
model = model.Model(args, checkpoint)
t = VideoTester(args, model, checkpoint)
t.test()
else:
if checkpoint.ok: if checkpoint.ok:
loader = data.Data(args) loader = data.Data(args)
model = model.Model(args, checkpoint) model = model.Model(args, checkpoint)
......
...@@ -3,6 +3,7 @@ from importlib import import_module ...@@ -3,6 +3,7 @@ from importlib import import_module
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.model_zoo
from torch.autograd import Variable from torch.autograd import Variable
class Model(nn.Module): class Model(nn.Module):
...@@ -29,7 +30,7 @@ class Model(nn.Module): ...@@ -29,7 +30,7 @@ class Model(nn.Module):
self.model = nn.DataParallel(self.model, range(args.n_GPUs)) self.model = nn.DataParallel(self.model, range(args.n_GPUs))
self.load( self.load(
ckp.dir, ckp.get_path('model'),
pre_train=args.pre_train, pre_train=args.pre_train,
resume=args.resume, resume=args.resume,
cpu=args.cpu cpu=args.cpu
...@@ -39,16 +40,14 @@ class Model(nn.Module): ...@@ -39,16 +40,14 @@ class Model(nn.Module):
def forward(self, x, idx_scale): def forward(self, x, idx_scale):
self.idx_scale = idx_scale self.idx_scale = idx_scale
target = self.get_model() target = self.get_model()
if hasattr(target, 'set_scale'): if hasattr(target, 'set_scale'): target.set_scale(idx_scale)
target.set_scale(idx_scale)
if self.self_ensemble and not self.training: if self.self_ensemble and not self.training:
if self.chop: if self.chop:
forward_function = self.forward_chop forward_function = self.forward_chop
else: else:
forward_function = self.model.forward forward_function = self.model.forward
return self.forward_x8(x, forward_function) return self.forward_x8(x, forward_function=forward_function)
elif self.chop and not self.training: elif self.chop and not self.training:
return self.forward_chop(x) return self.forward_chop(x)
else: else:
...@@ -66,22 +65,17 @@ class Model(nn.Module): ...@@ -66,22 +65,17 @@ class Model(nn.Module):
def save(self, apath, epoch, is_best=False): def save(self, apath, epoch, is_best=False):
target = self.get_model() target = self.get_model()
torch.save( save_dirs = [os.path.join(apath, 'model_latest.pt')]
target.state_dict(),
os.path.join(apath, 'model', 'model_latest.pt')
)
if is_best:
torch.save(
target.state_dict(),
os.path.join(apath, 'model', 'model_best.pt')
)
if is_best:
save_dirs.append(os.path.join(apath, 'model_best.pt'))
if self.save_models: if self.save_models:
torch.save( save_dirs.append(
target.state_dict(), os.path.join(apath, 'model_{}.pt'.format(epoch))
os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
) )
for s in save_dirs: torch.save(target.state_dict(), s)
def load(self, apath, pre_train='.', resume=-1, cpu=False): def load(self, apath, pre_train='.', resume=-1, cpu=False):
if cpu: if cpu:
kwargs = {'map_location': lambda storage, loc: storage} kwargs = {'map_location': lambda storage, loc: storage}
...@@ -91,11 +85,10 @@ class Model(nn.Module): ...@@ -91,11 +85,10 @@ class Model(nn.Module):
load_from = None load_from = None
if resume == -1: if resume == -1:
load_from = torch.load( load_from = torch.load(
os.path.join(apath, 'model', 'model_latest.pt'), os.path.join(apath, 'model_latest.pt'),
**kwargs **kwargs
) )
elif resume == 0: elif resume == 0:
if pre_train != '.':
if pre_train == 'download': if pre_train == 'download':
print('Download the model') print('Download the model')
dir_model = os.path.join('..', 'models') dir_model = os.path.join('..', 'models')
...@@ -105,17 +98,16 @@ class Model(nn.Module): ...@@ -105,17 +98,16 @@ class Model(nn.Module):
model_dir=dir_model, model_dir=dir_model,
**kwargs **kwargs
) )
else: elif pre_train != '':
print('Load the model from {}'.format(pre_train)) print('Load the model from {}'.format(pre_train))
load_from = torch.load(pre_train, **kwargs) load_from = torch.load(pre_train, **kwargs)
else: else:
load_from = torch.load( load_from = torch.load(
os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), os.path.join(apath, 'model_{}.pt'.format(resume)),
**kwargs **kwargs
) )
if load_from: if load_from: self.get_model().load_state_dict(load_from, strict=False)
self.get_model().load_state_dict(load_from, strict=False)
def forward_chop(self, *args, shave=10, min_size=160000): def forward_chop(self, *args, shave=10, min_size=160000):
if self.input_large: if self.input_large:
...@@ -152,8 +144,7 @@ class Model(nn.Module): ...@@ -152,8 +144,7 @@ class Model(nn.Module):
if not list_y: if not list_y:
list_y = [[_y] for _y in y] list_y = [[_y] for _y in y]
else: else:
for _list_y, _y in zip(list_y, y): for _list_y, _y in zip(list_y, y): _list_y.append(_y)
_list_y.append(_y)
h, w = scale * h, scale * w h, w = scale * h, scale * w
h_half, w_half = scale * h_half, scale * w_half h_half, w_half = scale * h_half, scale * w_half
...@@ -196,8 +187,7 @@ class Model(nn.Module): ...@@ -196,8 +187,7 @@ class Model(nn.Module):
list_x = [] list_x = []
for a in args: for a in args:
x = [a] x = [a]
for tf in 'v', 'h', 't': for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
x.extend([_transform(_x, tf) for _x in x])
list_x.append(x) list_x.append(x)
...@@ -208,8 +198,7 @@ class Model(nn.Module): ...@@ -208,8 +198,7 @@ class Model(nn.Module):
if not list_y: if not list_y:
list_y = [[_y] for _y in y] list_y = [[_y] for _y in y]
else: else:
for _list_y, _y in zip(list_y, y): for _list_y, _y in zip(list_y, y): _list_y.append(_y)
_list_y.append(_y)
for _list_y in list_y: for _list_y in list_y:
for i in range(len(_list_y)): for i in range(len(_list_y)):
......
...@@ -12,24 +12,22 @@ def default_conv(in_channels, out_channels, kernel_size, bias=True): ...@@ -12,24 +12,22 @@ def default_conv(in_channels, out_channels, kernel_size, bias=True):
padding=(kernel_size//2), bias=bias) padding=(kernel_size//2), bias=bias)
class MeanShift(nn.Conv2d): class MeanShift(nn.Conv2d):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1) super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std) std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
self.bias.data.div_(std)
self.requires_grad = False self.requires_grad = False
class BasicBlock(nn.Sequential): class BasicBlock(nn.Sequential):
def __init__( def __init__(
self, in_channels, out_channels, kernel_size, stride=1, bias=False, self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
bn=True, act=nn.ReLU(True)): bn=True, act=nn.ReLU(True)):
m = [nn.Conv2d( m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), stride=stride, bias=bias)
]
if bn: m.append(nn.BatchNorm2d(out_channels)) if bn: m.append(nn.BatchNorm2d(out_channels))
if act is not None: m.append(act) if act is not None: m.append(act)
super(BasicBlock, self).__init__(*m) super(BasicBlock, self).__init__(*m)
......
...@@ -18,17 +18,14 @@ class EDSR(nn.Module): ...@@ -18,17 +18,14 @@ class EDSR(nn.Module):
def __init__(self, args, conv=common.default_conv): def __init__(self, args, conv=common.default_conv):
super(EDSR, self).__init__() super(EDSR, self).__init__()
n_resblock = args.n_resblocks n_resblocks = args.n_resblocks
n_feats = args.n_feats n_feats = args.n_feats
kernel_size = 3 kernel_size = 3
scale = args.scale[0] scale = args.scale[0]
act = nn.ReLU(True) act = nn.ReLU(True)
self.url = url['r{}f{}x{}'.format(n_resblocks, n_feats, scale)] self.url = url['r{}f{}x{}'.format(n_resblocks, n_feats, scale)]
self.sub_mean = common.MeanShift(args.rgb_range)
rgb_mean = (0.4488, 0.4371, 0.4040) self.add_mean = common.MeanShift(args.rgb_range, sign=1)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
# define head module # define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)] m_head = [conv(args.n_colors, n_feats, kernel_size)]
...@@ -37,7 +34,7 @@ class EDSR(nn.Module): ...@@ -37,7 +34,7 @@ class EDSR(nn.Module):
m_body = [ m_body = [
common.ResBlock( common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblock) ) for _ in range(n_resblocks)
] ]
m_body.append(conv(n_feats, n_feats, kernel_size)) m_body.append(conv(n_feats, n_feats, kernel_size))
......
...@@ -16,14 +16,11 @@ class MDSR(nn.Module): ...@@ -16,14 +16,11 @@ class MDSR(nn.Module):
n_resblocks = args.n_resblocks n_resblocks = args.n_resblocks
n_feats = args.n_feats n_feats = args.n_feats
kernel_size = 3 kernel_size = 3
act = nn.ReLU(True)
self.scale_idx = 0 self.scale_idx = 0
self.url = url['r{}f{}'.format(n_resblocks, n_feats)] self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
self.sub_mean = common.MeanShift(args.rgb_range)
act = nn.ReLU(True) self.add_mean = common.MeanShift(args.rgb_range, sign=1)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
m_head = [conv(args.n_colors, n_feats, kernel_size)] m_head = [conv(args.n_colors, n_feats, kernel_size)]
...@@ -47,8 +44,6 @@ class MDSR(nn.Module): ...@@ -47,8 +44,6 @@ class MDSR(nn.Module):
m_tail = [conv(n_feats, args.n_colors, kernel_size)] m_tail = [conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head) self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body) self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail) self.tail = nn.Sequential(*m_tail)
......
...@@ -79,9 +79,7 @@ class RCAN(nn.Module): ...@@ -79,9 +79,7 @@ class RCAN(nn.Module):
act = nn.ReLU(True) act = nn.ReLU(True)
# RGB mean for DIV2K # RGB mean for DIV2K
rgb_mean = (0.4488, 0.4371, 0.4040) self.sub_mean = common.MeanShift(args.rgb_range)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module # define head module
modules_head = [conv(args.n_colors, n_feats, kernel_size)] modules_head = [conv(args.n_colors, n_feats, kernel_size)]
...@@ -99,7 +97,7 @@ class RCAN(nn.Module): ...@@ -99,7 +97,7 @@ class RCAN(nn.Module):
common.Upsampler(conv, scale, n_feats, act=False), common.Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)] conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) self.add_mean = common.MeanShift(args.rgb_range, sign=1)
self.head = nn.Sequential(*modules_head) self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body) self.body = nn.Sequential(*modules_body)
......
from model import common
import torch.nn as nn
# Download URLs for pretrained weights, keyed by architecture signature.
# NOTE(review): the empty string presumably means no pretrained model is
# published for this configuration — confirm against the loader.
url = {
    'r20f64': ''
}

def make_model(args, parent=False):
    """Factory hook invoked by the framework's model loader."""
    network = VDSR(args)
    return network
class VDSR(nn.Module):
    """VDSR super-resolution network: a plain stack of conv blocks that
    predicts a residual on top of the mean-shifted input.

    All layers are batch-norm-free BasicBlocks; the first maps colors to
    features, the last maps features back to colors with no activation.
    """

    def __init__(self, args, conv=common.default_conv):
        super(VDSR, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        self.url = url['r{}f{}'.format(n_resblocks, n_feats)]

        self.sub_mean = common.MeanShift(args.rgb_range)
        self.add_mean = common.MeanShift(args.rgb_range, sign=1)

        def basic_block(in_channels, out_channels, act=nn.ReLU(True)):
            # Every layer shares the same conv/kernel/bn configuration.
            return common.BasicBlock(
                conv, in_channels, out_channels, kernel_size,
                bn=False, act=act
            )

        # Body: head block, (n_resblocks - 2) middle blocks, tail block.
        layers = [basic_block(args.n_colors, n_feats)]
        layers += [
            basic_block(n_feats, n_feats) for _ in range(n_resblocks - 2)
        ]
        layers.append(basic_block(n_feats, args.n_colors, act=None))
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        shifted = self.sub_mean(x)
        res = self.body(shifted)
        res += shifted
        return self.add_mean(res)
...@@ -50,7 +50,7 @@ parser.add_argument('--model', default='EDSR', ...@@ -50,7 +50,7 @@ parser.add_argument('--model', default='EDSR',
parser.add_argument('--act', type=str, default='relu', parser.add_argument('--act', type=str, default='relu',
help='activation function') help='activation function')
parser.add_argument('--pre_train', type=str, default='.', parser.add_argument('--pre_train', type=str, default='',
help='pre-trained model directory') help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.', parser.add_argument('--extend', type=str, default='.',
help='pre-trained model directory') help='pre-trained model directory')
......
...@@ -5,7 +5,6 @@ from decimal import Decimal ...@@ -5,7 +5,6 @@ from decimal import Decimal
import utility import utility
import torch import torch
from torch.autograd import Variable
from tqdm import tqdm from tqdm import tqdm
class Trainer(): class Trainer():
......
...@@ -64,27 +64,30 @@ class checkpoint(): ...@@ -64,27 +64,30 @@ class checkpoint():
if not os.path.exists(path): os.makedirs(path) if not os.path.exists(path): os.makedirs(path)
_make_dir(self.dir) _make_dir(self.dir)
_make_dir(self.dir + '/model') _make_dir(self.get_path('model'))
_make_dir(self.dir + '/results') _make_dir(self.get_path('results'))
open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w'
self.log_file = open(self.dir + '/log.txt', open_type) self.log_file = open(self.get_path('log.txt'), open_type)
with open(self.dir + '/config.txt', open_type) as f: with open(self.get_path('config.txt'), open_type) as f:
f.write(now + '\n\n') f.write(now + '\n\n')
for arg in vars(args): for arg in vars(args):
f.write('{}: {}\n'.format(arg, getattr(args, arg))) f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n') f.write('\n')
def get_path(self, *subdir):
return os.path.join(self.dir, *subdir)
def save(self, trainer, epoch, is_best=False): def save(self, trainer, epoch, is_best=False):
trainer.model.save(self.dir, epoch, is_best=is_best) trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
trainer.loss.save(self.dir) trainer.loss.save(self.dir)
trainer.loss.plot_loss(self.dir, epoch) trainer.loss.plot_loss(self.dir, epoch)
self.plot_psnr(epoch) self.plot_psnr(epoch)
torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt')) torch.save(self.log, self.get_path('psnr_log.pt'))
torch.save( torch.save(
trainer.optimizer.state_dict(), trainer.optimizer.state_dict(),
os.path.join(self.dir, 'optimizer.pt') self.get_path('optimizer.pt')
) )
def add_log(self, log): def add_log(self, log):
...@@ -95,7 +98,7 @@ class checkpoint(): ...@@ -95,7 +98,7 @@ class checkpoint():
self.log_file.write(log + '\n') self.log_file.write(log + '\n')
if refresh: if refresh:
self.log_file.close() self.log_file.close()
self.log_file = open(self.dir + '/log.txt', 'a') self.log_file = open(self.get_path('log.txt'), 'a')
def done(self): def done(self):
self.log_file.close() self.log_file.close()
...@@ -115,14 +118,14 @@ class checkpoint(): ...@@ -115,14 +118,14 @@ class checkpoint():
plt.xlabel('Epochs') plt.xlabel('Epochs')
plt.ylabel('PSNR') plt.ylabel('PSNR')
plt.grid(True) plt.grid(True)
plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test)) plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test)))
plt.close(fig) plt.close(fig)
def save_results(self, filename, save_list, scale): def save_results(self, filename, save_list, scale):
filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale) filename = self.get_path('results', '{}_x{}_'.format(filename, scale))
postfix = ('SR', 'LR', 'HR') postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix): for v, p in zip(save_list, postfix):
normalized = v[0].data.mul(255 / self.args.rgb_range) normalized = v[0].mul(255 / self.args.rgb_range)
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
misc.imsave('{}{}.png'.format(filename, p), ndarr) misc.imsave('{}{}.png'.format(filename, p), ndarr)
......
import os
import math
import utility
from data import common
import torch
import cv2
from tqdm import tqdm
class VideoTester():
    """Runs a trained SR model over a video and writes the upscaled result.

    Frames are read in batches of ``args.batch_size``, super-resolved at
    each scale in ``args.scale``, and re-encoded with OpenCV's XVID writer
    at the source frame rate.
    """

    def __init__(self, args, my_model, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.model = my_model

        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))

    def test(self):
        torch.set_grad_enabled(False)

        self.ckp.write_log('\nEvaluation on video:')
        self.model.eval()

        timer_test = utility.timer()
        for idx_scale, scale in enumerate(self.scale):
            vidcap = cv2.VideoCapture(self.args.dir_demo)
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            total_batches = math.ceil(total_frames / self.args.batch_size)
            vidwri = cv2.VideoWriter(
                self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
                cv2.VideoWriter_fourcc(*'XVID'),
                int(vidcap.get(cv2.CAP_PROP_FPS)),
                (
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                )
            )

            tqdm_test = tqdm(range(total_batches), ncols=80)
            for _ in tqdm_test:
                fs = []
                for _ in range(self.args.batch_size):
                    success, lr = vidcap.read()
                    if success:
                        fs.append(lr)
                    else:
                        break
                # Guard: CAP_PROP_FRAME_COUNT is only a hint; the stream
                # may end before the predicted number of batches, and
                # torch.stack on an empty list would raise.
                if not fs:
                    break

                fs = common.set_channel(*fs, n_channels=self.args.n_colors)
                fs = common.np2Tensor(*fs, rgb_range=self.args.rgb_range)
                lr = torch.stack(fs, dim=0)
                lr, = self.prepare(lr)
                sr = self.model(lr, idx_scale)
                sr = utility.quantize(sr, self.args.rgb_range)

                # Bug fix: iterate over the frames actually produced, not
                # the nominal batch size — the last batch is usually short,
                # and indexing sr[i] past it raised IndexError.
                for i in range(sr.size(0)):
                    normalized = sr[i].mul(255 / self.args.rgb_range)
                    ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
                    vidwri.write(ndarr)

            # Bug fix: the capture and writer are locals of this loop; the
            # original released ``self.vidcap`` / ``self.vidwri``, which are
            # never set on this class, raising AttributeError and leaking
            # both handles.
            vidcap.release()
            vidwri.release()

        self.ckp.write_log(
            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )
        # Restore autograd so a caller that keeps training afterwards is
        # not silently left with gradients disabled globally.
        torch.set_grad_enabled(True)

    def prepare(self, *args):
        """Move tensors to the configured device, halving precision first
        when ``args.precision == 'half'``; returns a list in input order."""
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment