Commit a91ccea9 authored by Sanghyun Son

now testing jpeg augmentation

parent d99e5756
Merge request: !1 "Jan 09, 2018 updates"
data/__init__.py
 from importlib import import_module
 from dataloader import MSDataLoader
-from torch.utils.data.dataloader import default_collate
+from torch.utils.data import ConcatDataset
+
+# This is a simple wrapper class for ConcatDataset
+class MyConcatDataset(ConcatDataset):
+    def __init__(self, datasets):
+        super(MyConcatDataset, self).__init__(datasets)
+
+    def set_scale(self, idx_scale):
+        for d in self.datasets:
+            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

 class Data:
     def __init__(self, args):
         self.loader_train = None
         if not args.test_only:
-            module_train = import_module('data.' + args.data_train.lower())
-            trainset = getattr(module_train, args.data_train)(args)
+            datasets = []
+            for d in args.data_train:
+                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
+                m = import_module('data.' + module_name.lower())
+                datasets.append(getattr(m, module_name)(args, name=d))
+
             self.loader_train = MSDataLoader(
                 args,
-                trainset,
+                MyConcatDataset(datasets),
                 batch_size=args.batch_size,
                 shuffle=True,
                 pin_memory=not args.cpu
             )

-        if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
-            module_test = import_module('data.benchmark')
-            testset = getattr(module_test, 'Benchmark')(
-                args, train=False, name=args.data_test
-            )
-        else:
-            module_test = import_module('data.' + args.data_test.lower())
-            testset = getattr(module_test, args.data_test)(args, train=False)
-
-        self.loader_test = MSDataLoader(
-            args,
-            testset,
-            batch_size=1,
-            shuffle=False,
-            pin_memory=not args.cpu
-        )
+        self.loader_test = []
+        for d in args.data_test:
+            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
+                m = import_module('data.benchmark')
+                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
+            else:
+                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
+                m = import_module('data.' + module_name.lower())
+                testset = getattr(m, module_name)(args, train=False, name=d)
+
+            self.loader_test.append(MSDataLoader(
+                args,
+                testset,
+                batch_size=1,
+                shuffle=False,
+                pin_memory=not args.cpu
+            ))
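MyConcatDataset is the glue that lets a single MSDataLoader serve several datasets while still broadcasting scale changes to every child. A toy check of that behavior, using a hypothetical stand-in dataset (not part of the commit):

    from torch.utils.data import Dataset

    class _Toy(Dataset):
        def __init__(self): self.idx_scale = 0
        def __len__(self): return 4
        def __getitem__(self, i): return i
        def set_scale(self, idx_scale): self.idx_scale = idx_scale

    cat = MyConcatDataset([_Toy(), _Toy()])
    cat.set_scale(1)
    print(len(cat), cat.datasets[0].idx_scale)   # 8 1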
data/demo.py
@@ -15,7 +15,6 @@ class Demo(data.Dataset):
         self.scale = args.scale
         self.idx_scale = 0
         self.train = False
-        self.do_eval = False
         self.benchmark = benchmark

         self.filelist = []
@@ -25,8 +24,7 @@ class Demo(data.Dataset):
         self.filelist.sort()

     def __getitem__(self, idx):
-        filename = os.path.split(self.filelist[idx])[-1]
-        filename, _ = os.path.splitext(filename)
+        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
         lr = imageio.imread(self.filelist[idx])
         lr, = common.set_channel(lr, n_channels=self.args.n_colors)
         lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
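The two os.path calls collapse into one expression that strips both the directory and the extension; for example (hypothetical path):

    import os
    # '/data/demo/0001.png' -> '0001', same logic as the new line above
    print(os.path.splitext(os.path.basename('/data/demo/0001.png'))[0])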
data/div2k.py
@@ -3,6 +3,16 @@ from data import srdata

 class DIV2K(srdata.SRData):
     def __init__(self, args, name='DIV2K', train=True, benchmark=False):
+        data_range = [r.split('-') for r in args.data_range.split('/')]
+        if train:
+            data_range = data_range[0]
+        else:
+            if args.test_only and len(data_range) == 1:
+                data_range = data_range[0]
+            else:
+                data_range = data_range[1]
+
+        self.begin, self.end = list(map(lambda x: int(x), data_range))
         super(DIV2K, self).__init__(
             args, name=name, train=train, benchmark=benchmark
         )
@@ -17,8 +27,6 @@ class DIV2K(srdata.SRData):
     def _set_filesystem(self, dir_data):
         super(DIV2K, self)._set_filesystem(dir_data)
         self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
-        if self.input_large:
-            self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubicL')
-        else:
-            self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
+        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
+        if self.input_large: self.dir_lr += 'L'
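This --data_range parsing moved up from SRData (see srdata.py below). A worked sketch with an assumed argument value:

    # assuming --data_range 1-800/801-810 (hypothetical split)
    data_range = [r.split('-') for r in '1-800/801-810'.split('/')]
    print(data_range)   # [['1', '800'], ['801', '810']]
    # train=True                -> begin, end = 1, 800
    # train=False               -> begin, end = 801, 810
    # test_only with one range  -> that single range is used for testing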
data/div2kjpeg.py (new file)
+import os
+
+from data import srdata
+from data import div2k
+
+class DIV2KJPEG(div2k.DIV2K):
+    def __init__(self, args, name='', train=True, benchmark=False):
+        self.q_factor = int(name.replace('DIV2K-Q', ''))
+        super(DIV2KJPEG, self).__init__(
+            args, name=name, train=train, benchmark=benchmark
+        )
+
+    def _set_filesystem(self, dir_data):
+        self.apath = os.path.join(dir_data, 'DIV2K')
+        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
+        self.dir_lr = os.path.join(
+            self.apath, 'DIV2K_Q{}'.format(self.q_factor)
+        )
+        if self.input_large: self.dir_lr += 'L'
+        self.ext = ('.png', '.jpg')
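The class assumes dataset names of the form DIV2K-Q<quality>; everything after the prefix becomes the JPEG quality factor:

    # hypothetical name, same parsing as __init__ above
    name = 'DIV2K-Q75'
    print(int(name.replace('DIV2K-Q', '')))   # 75
    # LR inputs are then .jpg files under <dir_data>/DIV2K/DIV2K_Q75,
    # while HR targets stay in DIV2K_train_HR as .png files.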
data/srdata.py
@@ -22,16 +22,6 @@ class SRData(data.Dataset):
         self.scale = args.scale
         self.idx_scale = 0
-
-        data_range = [r.split('-') for r in args.data_range.split('/')]
-        if train:
-            data_range = data_range[0]
-        else:
-            if args.test_only and len(data_range) == 1:
-                data_range = data_range[0]
-            else:
-                data_range = data_range[1]
-
-        self.begin, self.end = list(map(lambda x: int(x), data_range))
         self._set_filesystem(args.dir_data)
         if args.ext.find('img') < 0:
             path_bin = os.path.join(self.apath, 'bin')
@@ -85,8 +75,9 @@ class SRData(data.Dataset):
             )

         if train:
-            self.repeat \
-                = args.test_every // (len(self.images_hr) // args.batch_size)
+            n_patches = args.batch_size * args.test_every
+            n_images = len(args.data_train) * len(self.images_hr)
+            self.repeat = max(n_patches // n_images, 1)

     # Below functions are used to prepare images
     def _scan(self):
@@ -106,10 +97,10 @@ class SRData(data.Dataset):
         return names_hr, names_lr

     def _set_filesystem(self, dir_data):
-        bicubic_type = 'LR_bicubic' if not self.input_large else 'LR_bicubicL'
         self.apath = os.path.join(dir_data, self.name)
         self.dir_hr = os.path.join(self.apath, 'HR')
-        self.dir_lr = os.path.join(self.apath, bicubic_type)
+        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+        if self.input_large: self.dir_lr += 'L'
         self.ext = ('.png', '.png')

     def _name_hrbin(self):
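The new repeat rule keeps the number of training patches per epoch roughly constant however many datasets are concatenated (the old formula could also divide by zero when a dataset held fewer images than one batch). A worked example with hypothetical settings:

    batch_size, test_every = 16, 1000
    n_train_sets, n_images_hr = 2, 800        # e.g. DIV2K + DIV2K-Q75
    n_patches = batch_size * test_every       # 16000 patches per epoch
    n_images = n_train_sets * n_images_hr     # 1600 images in total
    print(max(n_patches // n_images, 1))      # each image repeats 10 times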
demo.sh
 # EDSR baseline model (x2)
-#python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_baseline_x2 --reset
+#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset
+python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75

 # EDSR baseline model (x3) - from EDSR baseline model (x2)
-#python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
+#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]

 # EDSR baseline model (x4) - from EDSR baseline model (x2)
-#python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
+#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]

 # EDSR in the paper (x2)
-#python main.py --model EDSR --scale 2 --save EDSR_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
+#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset

 # EDSR in the paper (x3) - from EDSR (x2)
-#python main.py --model EDSR --scale 3 --save EDSR_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
+#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir]

 # EDSR in the paper (x4) - from EDSR (x2)
-#python main.py --model EDSR --scale 4 --save EDSR_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
+#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]

 # MDSR baseline model
 #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
@@ -25,11 +26,7 @@
 # Standard benchmarks (Ex. EDSR_baseline_x4)
 #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble

-#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
-#python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
-#python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
-#python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
-#python main.py --data_test DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
+#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble

 # Test your own images
 #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
@@ -38,7 +35,7 @@
 #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results

 # Advanced - Training with adversarial loss
-#python main.py --template GAN --scale 4 --save EDSR_GAN --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
+#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download

 # RDN BI model (x2)
 #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
option.py
@@ -144,11 +144,15 @@ parser.add_argument('--print_every', type=int, default=100,
                     help='how many batches to wait before logging training status')
 parser.add_argument('--save_results', action='store_true',
                     help='save output results')
+parser.add_argument('--save_gt', action='store_true',
+                    help='save low-resolution and high-resolution images together')

 args = parser.parse_args()
 template.set_template(args)

 args.scale = list(map(lambda x: int(x), args.scale.split('+')))
+args.data_train = args.data_train.split('+')
+args.data_test = args.data_test.split('+')

 if args.epochs == 0:
     args.epochs = 1e8
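The '+'-separated lists produced here feed the loops added in data/__init__.py; for example:

    print('DIV2K+DIV2K-Q75'.split('+'))   # ['DIV2K', 'DIV2K-Q75']
    print('Set5+B100+DIV2K'.split('+'))   # ['Set5', 'B100', 'DIV2K']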
trainer.py
@@ -74,61 +74,58 @@ class Trainer():
         self.error_last = self.loss.log[-1, -1]

     def test(self):
+        torch.set_grad_enabled(False)
+
         epoch = self.scheduler.last_epoch + 1
         self.ckp.write_log('\nEvaluation:')
-        self.ckp.add_log(torch.zeros(1, len(self.scale)))
+        self.ckp.add_log(
+            torch.zeros(1, len(self.loader_test), len(self.scale))
+        )
         self.model.eval()

         timer_test = utility.timer()
         if self.args.save_results: self.ckp.begin_background()
-        with torch.no_grad():
-            for idx_scale, scale in enumerate(self.scale):
-                eval_acc = 0
-                self.loader_test.dataset.set_scale(idx_scale)
-                tqdm_test = tqdm(self.loader_test, ncols=80)
-                for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test):
-                    filename = filename[0]
-                    no_eval = (hr.nelement() == 1)
+        for idx_data, d in enumerate(self.loader_test):
+            for idx_scale, scale in enumerate(self.scale):
+                d.dataset.set_scale(idx_scale)
+                for lr, hr, filename, _ in tqdm(d, ncols=80):
                     lr, hr = self.prepare(lr, hr)
                     sr = self.model(lr, idx_scale)
                     sr = utility.quantize(sr, self.args.rgb_range)

                     save_list = [sr]
-                    if not no_eval:
-                        eval_acc += utility.calc_psnr(
-                            sr, hr, scale, self.args.rgb_range,
-                            benchmark=self.loader_test.dataset.benchmark
-                        )
-                        save_list.extend([lr, hr])
-
-                    if self.args.save_results:
-                        self.ckp.save_results(filename, save_list, scale)
-
-                self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
+                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
+                        sr, hr, scale, self.args.rgb_range, dataset=d
+                    )
+                    if self.args.save_gt: save_list.extend([lr, hr])
+                    self.ckp.save_results(d, filename[0], save_list, scale)
+
+                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                 best = self.ckp.log.max(0)
                 self.ckp.write_log(
                     '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
-                        self.args.data_test,
+                        d.dataset.name,
                         scale,
-                        self.ckp.log[-1, idx_scale],
-                        best[0][idx_scale],
-                        best[1][idx_scale] + 1
+                        self.ckp.log[-1, idx_data, idx_scale],
+                        best[0][idx_data, idx_scale],
+                        best[1][idx_data, idx_scale] + 1
                     )
                 )

-        self.ckp.write_log(
-            'Forward time: {:.2f}s\n'.format(timer_test.toc())
-        )
+        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
         self.ckp.write_log('Saving...')

         if self.args.save_results: self.ckp.end_background()

         if not self.args.test_only:
-            self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
+            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

         self.ckp.write_log(
-            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
+            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
         )
+
+        torch.set_grad_enabled(True)

     def prepare(self, *args):
         device = torch.device('cpu' if self.args.cpu else 'cuda')
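With multiple test sets, the PSNR log gains a dataset dimension: one row per epoch, one column per test set, one slice per scale. A shape sketch under assumed sizes (not commit code):

    import torch
    log = torch.zeros(5, 2, 3)      # 5 epochs x 2 test sets x 3 scales
    log[4, 1, 0] = 32.1             # hypothetical PSNR entry
    best = log.max(0)               # reduce over the epoch dimension
    print(best[0][1, 0])            # best PSNR for test set 1, scale 0
    print(best[1][1, 0] + 1)        # 1-indexed epoch where it occurred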
utility.py
@@ -50,25 +50,23 @@ class checkpoint():
         if args.load == '.':
             if args.save == '.': args.save = now
-            self.dir = '../experiment/' + args.save
+            self.dir = os.path.join('..', 'experiment', args.save)
         else:
-            self.dir = '../experiment/' + args.load
+            self.dir = os.path.join('..', 'experiment', args.load)
             if not os.path.exists(self.dir):
                 args.load = '.'
             else:
-                self.log = torch.load(self.dir + '/psnr_log.pt')
+                self.log = torch.load(self.get_path('psnr_log.pt'))
                 print('Continue from epoch {}...'.format(len(self.log)))

         if args.reset:
             os.system('rm -rf ' + self.dir)
             args.load = '.'

-        def _make_dir(path):
-            if not os.path.exists(path): os.makedirs(path)
-
-        _make_dir(self.dir)
-        _make_dir(self.get_path('model'))
-        _make_dir(self.get_path('results'))
+        os.makedirs(self.dir, exist_ok=True)
+        os.makedirs(self.get_path('model'), exist_ok=True)
+        for d in args.data_test:
+            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)

         open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
         self.log_file = open(self.get_path('log.txt'), open_type)
@@ -110,20 +108,21 @@ class checkpoint():
     def plot_psnr(self, epoch):
         axis = np.linspace(1, epoch, epoch)
-        label = 'SR on {}'.format(self.args.data_test)
-        fig = plt.figure()
-        plt.title(label)
-        for idx_scale, scale in enumerate(self.args.scale):
-            plt.plot(
-                axis,
-                self.log[:, idx_scale].numpy(),
-                label='Scale {}'.format(scale)
-            )
-        plt.legend()
-        plt.xlabel('Epochs')
-        plt.ylabel('PSNR')
-        plt.grid(True)
-        plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test)))
-        plt.close(fig)
+        for idx_data, d in enumerate(self.args.data_test):
+            label = 'SR on {}'.format(d)
+            fig = plt.figure()
+            plt.title(label)
+            for idx_scale, scale in enumerate(self.args.scale):
+                plt.plot(
+                    axis,
+                    self.log[:, idx_data, idx_scale].numpy(),
+                    label='Scale {}'.format(scale)
+                )
+            plt.legend()
+            plt.xlabel('Epochs')
+            plt.ylabel('PSNR')
+            plt.grid(True)
+            plt.savefig(self.get_path('test_{}.pdf'.format(d)))
+            plt.close(fig)

     def begin_background(self):
@@ -148,8 +147,13 @@ class checkpoint():
         while not self.queue.empty(): time.sleep(1)
         for p in self.process: p.join()

-    def save_results(self, filename, save_list, scale):
-        filename = self.get_path('results', '{}_x{}_'.format(filename, scale))
-        postfix = ('SR', 'LR', 'HR')
-        for v, p in zip(save_list, postfix):
-            normalized = v[0].mul(255 / self.args.rgb_range)
+    def save_results(self, dataset, filename, save_list, scale):
+        if self.args.save_results:
+            filename = self.get_path(
+                'results-{}'.format(dataset.dataset.name),
+                '{}_x{}_'.format(filename, scale)
+            )
+            postfix = ('SR', 'LR', 'HR')
+            for v, p in zip(save_list, postfix):
+                normalized = v[0].mul(255 / self.args.rgb_range)
@@ -160,16 +164,18 @@ def quantize(img, rgb_range):
     pixel_range = 255 / rgb_range
     return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)

-def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
-    diff = (sr - hr).data.div(rgb_range)
-    if benchmark:
+def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
+    if hr.nelement() == 1: return 0
+
+    diff = (sr - hr) / rgb_range
+    if dataset and dataset.dataset.benchmark:
         shave = scale
         if diff.size(1) > 1:
             convert = diff.new(1, 3, 1, 1)
             convert[0, 0, 0, 0] = 65.738
             convert[0, 1, 0, 0] = 129.057
             convert[0, 2, 0, 0] = 25.064
-            diff.mul_(convert).div_(256)
+            diff *= (convert / 256)
             diff = diff.sum(dim=1, keepdim=True)
     else:
         shave = scale + 6
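The constants 65.738, 129.057, and 25.064 are the BT.601 RGB-to-luma weights scaled by 256, so benchmark PSNR is measured on the Y channel only. The function's tail is truncated in this view; the sketch below completes it from the standard PSNR definition (MSE over border-shaved pixels) and is an assumption, not the commit's exact code:

    import math
    import torch

    def psnr_sketch(sr, hr, scale, rgb_range, y_channel=False):
        # hedged reconstruction of calc_psnr's truncated tail
        diff = (sr - hr) / rgb_range
        if y_channel:
            # BT.601 luma weights x 256, matching the constants above
            convert = torch.tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
            diff = (diff * convert / 256).sum(dim=1, keepdim=True)
            shave = scale
        else:
            shave = scale + 6
        valid = diff[..., shave:-shave, shave:-shave]
        mse = valid.pow(2).mean().item()
        return -10 * math.log10(mse)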
videotester.py
@@ -20,11 +20,12 @@ class VideoTester():
         self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))

     def test(self):
+        torch.set_grad_enabled(False)
+
         self.ckp.write_log('\nEvaluation on video:')
         self.model.eval()

         timer_test = utility.timer()
-        torch.set_grad_enabled(False)
         for idx_scale, scale in enumerate(self.scale):
             vidcap = cv2.VideoCapture(self.args.dir_demo)
             total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -57,8 +58,9 @@ class VideoTester():
         vidwri.release()

         self.ckp.write_log(
-            'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
+            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
         )
+        torch.set_grad_enabled(True)

     def prepare(self, *args):
         device = torch.device('cpu' if self.args.cpu else 'cuda')
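Unlike the block-scoped torch.no_grad() that the trainer used before this commit, set_grad_enabled toggles autograd globally until it is flipped back, which is why the test methods now bracket their bodies with a False/True pair. A small illustration:

    import torch
    x = torch.ones(1, requires_grad=True)
    torch.set_grad_enabled(False)
    y = x * 2                    # no autograd graph is recorded
    torch.set_grad_enabled(True)
    z = x * 2                    # graph is recorded again
    print(y.requires_grad, z.requires_grad)   # False True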