Skip to content
Snippets Groups Projects
Commit 288f32f3 authored by Sanhyun Son's avatar Sanhyun Son
Browse files

tqdm for evaluation

parent 15169c60
Branches
No related tags found
1 merge request!1Jan 09, 2018 updates
...@@ -33,6 +33,10 @@ We provide scripts for reproducing all the results from our paper. You can train ...@@ -33,6 +33,10 @@ We provide scripts for reproducing all the results from our paper. You can train
## Dependencies ## Dependencies
* Python (Tested with 3.6) * Python (Tested with 3.6)
* PyTorch >= 0.3.1 * PyTorch >= 0.3.1
* numpy
* scipy
* matplotlib
* tqdm
## Code ## Code
Clone this repository into any place you want. Clone this repository into any place you want.
......
...@@ -11,16 +11,15 @@ import torch.utils.data as data ...@@ -11,16 +11,15 @@ import torch.utils.data as data
class Demo(data.Dataset): class Demo(data.Dataset):
def __init__(self, args, train=False): def __init__(self, args, train=False):
self.args = args self.args = args
self.train = False self.name = 'Demo'
self.name = 'MyImage'
self.scale = args.scale self.scale = args.scale
self.idx_scale = 0 self.idx_scale = 0
apath = '../test' self.train = False
self.filelist = [] self.filelist = []
for f in os.listdir(apath): for f in os.listdir(args.dir_demo):
if f.find('.png') >= 0 or f.find('.jp') >= 0: if f.find('.png') >= 0 or f.find('.jp') >= 0:
self.filelist.append(os.path.join(apath, f)) self.filelist.append(os.path.join(args.dir_demo, f))
self.filelist.sort() self.filelist.sort()
def __getitem__(self, idx): def __getitem__(self, idx):
...@@ -36,3 +35,4 @@ class Demo(data.Dataset): ...@@ -36,3 +35,4 @@ class Demo(data.Dataset):
def set_scale(self, idx_scale): def set_scale(self, idx_scale):
self.idx_scale = idx_scale self.idx_scale = idx_scale
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#python main.py --data_test Set5 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test Set5 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test Set14 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test Set14 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test B100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test B100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test Urban100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble python main.py --data_test Urban100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test DIV2K --ext img --n_val 100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble #python main.py --data_test DIV2K --ext img --n_val 100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble #python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#python main.py --data_test DIV2K --ext img --n_val 100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble #python main.py --data_test DIV2K --ext img --n_val 100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
# Test your own images # Test your own images
python main.py --data_test Demo --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --save_results #python main.py --data_test Demo --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --save_results
# Advanced - Test with JPEG images # Advanced - Test with JPEG images
#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train ../experiment/model/MDSR_baseline_jpeg.pt --test_only --save_results #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train ../experiment/model/MDSR_baseline_jpeg.pt --test_only --save_results
......
...@@ -13,7 +13,7 @@ checkpoint = utility.checkpoint(args) ...@@ -13,7 +13,7 @@ checkpoint = utility.checkpoint(args)
if checkpoint.ok: if checkpoint.ok:
loader = data.Data(args) loader = data.Data(args)
model = model.Model(args, checkpoint) model = model.Model(args, checkpoint)
loss = loss.Loss(args, checkpoint) loss = loss.Loss(args, checkpoint) if not args.test_only else None
t = Trainer(args, loader, model, loss, checkpoint) t = Trainer(args, loader, model, loss, checkpoint)
while not t.terminate(): while not t.terminate():
t.train() t.train()
......
...@@ -21,6 +21,8 @@ parser.add_argument('--seed', type=int, default=1, ...@@ -21,6 +21,8 @@ parser.add_argument('--seed', type=int, default=1,
# Data specifications # Data specifications
parser.add_argument('--dir_data', type=str, default='../../../dataset', parser.add_argument('--dir_data', type=str, default='../../../dataset',
help='dataset directory') help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K', parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name') help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K', parser.add_argument('--data_test', type=str, default='DIV2K',
......
...@@ -6,6 +6,7 @@ import utility ...@@ -6,6 +6,7 @@ import utility
import torch import torch
from torch.autograd import Variable from torch.autograd import Variable
from tqdm import tqdm
class Trainer(): class Trainer():
def __init__(self, args, loader, my_model, my_loss, ckp): def __init__(self, args, loader, my_model, my_loss, ckp):
...@@ -83,7 +84,8 @@ class Trainer(): ...@@ -83,7 +84,8 @@ class Trainer():
for idx_scale, scale in enumerate(self.scale): for idx_scale, scale in enumerate(self.scale):
eval_acc = 0 eval_acc = 0
self.loader_test.dataset.set_scale(idx_scale) self.loader_test.dataset.set_scale(idx_scale)
for idx_img, (lr, hr, _) in enumerate(self.loader_test): tqdm_test = tqdm(self.loader_test, ncols=80)
for idx_img, (lr, hr, _) in enumerate(tqdm_test):
no_eval = isinstance(hr[0], torch._six.string_classes) no_eval = isinstance(hr[0], torch._six.string_classes)
if no_eval: if no_eval:
lr = self.prepare([lr], volatile=True)[0] lr = self.prepare([lr], volatile=True)[0]
...@@ -108,34 +110,25 @@ class Trainer(): ...@@ -108,34 +110,25 @@ class Trainer():
self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test) self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
best = self.ckp.log.max(0) best = self.ckp.log.max(0)
performance = 'PSNR: {:.3f}'.format(
self.ckp.log[-1, idx_scale]
)
self.ckp.write_log( self.ckp.write_log(
'[{} x{}]\t{} (Best: {:.3f} from epoch {})'.format( '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} from epoch {})'.format(
self.args.data_test, scale, performance, self.args.data_test, scale, self.ckp.log[-1, idx_scale],
best[0][idx_scale], best[1][idx_scale] + 1 best[0][idx_scale], best[1][idx_scale] + 1
) )
) )
self.ckp.write_log( self.ckp.write_log(
'Time: {:.2f}s\n'.format(timer_test.toc()), refresh=True 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
) )
if not self.args.test_only: if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
def prepare(self, l, volatile=False): def prepare(self, l, volatile=False):
def _prepare(idx, tensor): def _prepare(idx, tensor):
if not self.args.cpu: if not self.args.cpu: tensor = tensor.cuda()
tensor = tensor.cuda() if self.args.precision == 'half': tensor = tensor.half()
if self.args.precision == 'half':
tensor = tensor.half()
# Only test lr can be volatile # Only test lr can be volatile
var = Variable(tensor, volatile=(volatile and idx==0)) return Variable(tensor, volatile=(volatile and idx==0))
return var
return [_prepare(i, _l) for i, _l in enumerate(l)] return [_prepare(i, _l) for i, _l in enumerate(l)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment