Commit f5f492b0 authored by Sanhyun Son's avatar Sanhyun Son

Mar 29, 2018 updates

    We now provide all models from our paper.
    We also provide the 'MDSR_baseline_jpeg' model, which suppresses JPEG artifacts in the original low-resolution image. Please use it if your inputs suffer from JPEG compression artifacts.
    The 'MyImage' dataset is renamed to 'Demo', and it now works more efficiently than before.
    Some code and scripts have been rewritten.
parent a9303eaf
# EDSR-PyTorch
![](/figs/main.png)
This repository is an official PyTorch implementation of the paper **"Enhanced Deep Residual Networks for Single Image Super-Resolution"** from **CVPRW 2017, 2nd NTIRE**.
You can find the original code and more information from [here](https://github.com/LimBee/NTIRE2017).
If you find our work useful in your research or publication, please cite our work:
```
@InProceedings{Lim_2017_CVPR_Workshops,
    author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
    title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month = {July},
    year = {2017}
}
```
We provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use a pre-trained model to enlarge your images.
**Differences from the Torch version**
* Code is much more compact. (All unnecessary parts removed.)
* Models are about half the size.
* Slightly better performance.
* Training and evaluation require less memory.
* Python-based.
**Recent updates**
* Mar 29, 2018
    * We now provide all models from our paper.
    * We also provide the ``MDSR_baseline_jpeg`` model, which suppresses JPEG artifacts in the original low-resolution image. Please use it if your inputs suffer from JPEG compression artifacts.
    * The ``MyImage`` dataset is renamed to ``Demo``, and it now works more efficiently than before.
    * Some code and scripts have been rewritten.
## Dependencies
* Python (Tested with 3.6)
* PyTorch >= 0.3.1
## Code
Clone this repository into any place you want.
```bash
git clone https://github.com/thstkdgus35/EDSR-PyTorch
cd EDSR-PyTorch
```
## Quick start (Demo)
You can test our super-resolution algorithm with your own images. Place your images in the ``test`` folder (like ``test/<your_image>``). We support **png** and **jpeg** files.

Run the script in the ``code`` folder. Before you run the demo, uncomment the line you want to execute in ``demo.sh``.
```bash
cd code # You are now in */EDSR-PyTorch/code
sh demo.sh
```
You can find the result images in the ```experiment/test/results``` folder.
| Model | Scale | File name (.pt) | Parameters | **PSNR (PyTorch)** | PSNR (Torch7) |
| --- | --- | --- | --- | --- | --- |
| **EDSR** | 2 | EDSR_baseline_x2 | 1.5M | 34.61 | 34.55 |
| **EDSR** | 3 | EDSR_baseline_x3 | 1.5M | 30.92 | 30.90 |
| **EDSR** | 4 | EDSR_baseline_x4 | 1.5M | 28.95 | 28.94 |
| **MDSR** | 2 | MDSR_baseline | 3.2M | 34.63 | 34.60 |
| | 3 | | | 30.94 | 30.91 |
| | 4 | | | 28.97 | 28.95 |
*Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (486MB).

**We measured PSNR using DIV2K 0801 ~ 0900, without self-ensemble.
You can evaluate your models on widely-used benchmark datasets:
[Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html),
[Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests),
[B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),
[Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
For these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. Download [this file](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB) and untar it to any place you want. Then, set ``--dir_data <where_benchmark_folder_located>`` to evaluate the models.
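For reference, the Y-channel metric can be written in a few lines. Below is a hypothetical stand-alone sketch that mirrors ``calc_PSNR`` in ``utility.py`` (shown later in this commit); inputs are assumed to be HxWx3 arrays scaled to [0, 1].

```python
import numpy as np

def y_psnr(sr, hr, scale):
    # RGB -> Y difference with the ITU-R BT.601 weights used by calc_PSNR
    weights = np.array([65.738, 129.057, 25.064]) / 256
    y_diff = (sr - hr) @ weights
    # Shave a scale-pixel border before computing the MSE
    valid = y_diff[scale:-scale, scale:-scale]
    return -10 * np.log10(np.mean(valid ** 2))
```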
## How to train EDSR and MDSR
We used the [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB). Unpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```code/option.py``` to the place where the DIV2K images are located.

We recommend pre-processing the images before training. This step decodes all **png** files and saves them as binaries. Use the ``--ext sep_reset`` argument on your first run; the decoded files are saved in the same directory as the DIV2K png files. After that, you can skip the decoding step and use the saved binaries with ``--ext sep``. If you have enough RAM (>= 32GB), you can instead use ``--ext bin`` to pack all DIV2K images into one binary file.
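The caching step itself is simple. Here is a minimal sketch of what ``--ext sep_reset`` does per image, mirroring the ``SRData`` loader shown later in this commit (``scipy.misc.imread`` is what the code base uses):

```python
import numpy as np
import scipy.misc as misc

def cache_as_npy(png_path):
    # Decode once (slow), then save the raw array next to the png file;
    # later runs with --ext sep can np.load() it instead of re-decoding
    img = misc.imread(png_path)
    npy_path = png_path.replace('.png', '.npy')
    np.save(npy_path, img)
    return npy_path
```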
You can train EDSR and MDSR by yourself. All scripts are provided in ``code/demo.sh``. Note that EDSR (x3, x4) requires a pre-trained EDSR (x2) model. You can ignore this constraint by removing the ```--pre_train <x2 model>``` argument.
```bash
cd code # You are now in */EDSR-PyTorch/code
sh demo.sh
```
* ``--ext img`` is now the default setting. Although we recommend ``--ext bin`` for training, please use ``--ext img`` with ``--test_only``.
* The skip_batch operation is implemented. Use the ``--skip_threshold`` argument to skip batches you want to ignore (see the sketch after this list). Although this function is not exactly the same as in the Torch7 version, it will work as expected.
* Mar 20, 2018
    * Use ``--ext sep_reset`` to pre-decode large png files. The decoded files will be saved in the same directory as the DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
    * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
    * Changed the behavior of skip_batch.
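As a quick illustration of the skip_batch rule above (hypothetical numbers; the real check lives in ``Trainer`` below): a batch is skipped when its loss exceeds ``--skip_threshold`` times the loss level the trainer records as ``error_last``.

```python
skip_threshold = 1e6   # hypothetical value; a large threshold never skips
error_last = 0.01      # trainer's record of the recent loss level
loss = 0.5             # current batch loss

if loss < skip_threshold * error_last:
    pass               # normal backward() and optimizer step
else:
    print('Skip this batch! (Loss: {})'.format(loss))
```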
import os
import os.path
import random
import math
import errno

from data import common

import scipy.misc as misc

import torch
import torch.utils.data as data
from torchvision import transforms

class Demo(data.Dataset):
    def __init__(self, args, train=False):
        self.args = args
        self.train = False
        # ... (a few unchanged lines omitted in the commit diff) ...

        apath = '../test'
        self.filelist = []
        if not train:
            # Collect png/jpeg files directly instead of try-reading
            # every file in the folder
            for f in os.listdir(apath):
                if f.find('.png') >= 0 or f.find('.jp') >= 0:
                    self.filelist.append(os.path.join(apath, f))

    def __getitem__(self, idx):
        filename = os.path.split(self.filelist[idx])[-1]
        filename, _ = os.path.splitext(filename)
        lr = misc.imread(self.filelist[idx])
        lr = common.set_channel([lr], self.args.n_colors)[0]

        return common.np2Tensor([lr], self.args.rgb_range)[0], filename

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
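# --- Example (not part of the commit): driving the Demo dataset -------------
# Hypothetical usage of the new Demo class, assuming an args namespace that
# carries the fields read above (n_colors, rgb_range) plus the scale
# bookkeeping hidden in the elided constructor lines.
from types import SimpleNamespace
from torch.utils.data import DataLoader

args = SimpleNamespace(n_colors=3, rgb_range=255, scale=[4])
demo_set = Demo(args, train=False)      # scans ../test for png/jpeg files
loader = DataLoader(demo_set, batch_size=1, shuffle=False)
for lr, filename in loader:
    print(filename[0], lr.size())       # e.g. ('puppy', torch.Size([1, 3, H, W]))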
import random

import numpy as np
import skimage.color as sc

import torch
from torchvision import transforms

def get_patch(img_in, img_tar, patch_size, scale, multi_scale=False):
    ih, iw = img_in.shape[:2]

    p = scale if multi_scale else 1
    tp = p * patch_size
    # ... (random patch coordinates unchanged, omitted in the commit diff) ...

    return img_in, img_tar

def set_channel(l, n_channel):
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channel == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channel == 3 and c == 1:
            img = np.concatenate([img] * n_channel, 2)

        return img

    return [_set_channel(_l) for _l in l]

def np2Tensor(l, rgb_range):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(_l) for _l in l]

def augment(l, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        # ... (flip/rotate logic unchanged, omitted in the commit diff) ...
        return img

    return [_augment(_l) for _l in l]
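# --- Example (not part of the commit): the new list-based helpers -----------
# set_channel / np2Tensor / augment now take and return lists, so an LR/HR
# pair and a lone demo image share one code path. A hypothetical round trip:
import numpy as np

lr = np.zeros((24, 24, 3), dtype=np.uint8)   # hypothetical x2 LR image
hr = np.zeros((48, 48, 3), dtype=np.uint8)   # its HR counterpart

lr, hr = set_channel([lr, hr], n_channel=3)      # no-op for 3-channel input
lr_t, hr_t = np2Tensor([lr, hr], rgb_range=255)  # CHW float tensors
print(lr_t.size(), hr_t.size())   # torch.Size([3, 24, 24]) torch.Size([3, 48, 48])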
class SRData(data.Dataset):
    # ... (unchanged constructor lines omitted in the commit diff) ...

        if args.ext.find('reset') >= 0:
            print('Preparing separated binary files')
            for v in self.images_hr:
                hr = misc.imread(v)
                name_sep = v.replace(self.ext, '.npy')
                np.save(name_sep, hr)
            for si, s in enumerate(self.scale):
                for v in self.images_lr[si]:
                    lr = misc.imread(v)
                    name_sep = v.replace(self.ext, '.npy')
                    np.save(name_sep, lr)

            self.images_hr = [
                v.replace(self.ext, '.npy') for v in self.images_hr
            ]
            # ... (matching substitution for the LR file lists omitted) ...

    # ... (abstract helpers omitted in the commit diff) ...
        raise NotImplementedError

    def __getitem__(self, idx):
        lr, hr = self._load_file(idx)
        lr, hr = self._get_patch(lr, hr)
        lr, hr = common.set_channel([lr, hr], self.args.n_colors)

        return common.np2Tensor([lr, hr], self.args.rgb_range)

    def __len__(self):
        return len(self.images_hr)

    # ... (unchanged lines omitted in the commit diff) ...

    def _load_file(self, idx):
        idx = self._get_index(idx)
        lr = self.images_lr[self.idx_scale][idx]
        hr = self.images_hr[idx]
        if self.args.ext == 'img':
            lr = misc.imread(lr)
            hr = misc.imread(hr)
        elif self.args.ext.find('sep') >= 0:
            lr = np.load(lr)
            hr = np.load(hr)

        return lr, hr

    def _get_patch(self, lr, hr):
        patch_size = self.args.patch_size
        scale = self.scale[self.idx_scale]
        multi_scale = len(self.scale) > 1
        if self.train:
            lr, hr = common.get_patch(
                lr, hr, patch_size, scale, multi_scale=multi_scale
            )
            lr, hr = common.augment([lr, hr])
        else:
            # Trim HR so its size is exactly scale times the LR size
            ih, iw = lr.shape[0:2]
            hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
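# --- Example (not part of the commit): why the eval branch crops hr ---------
# Benchmark HR images are not always exact scale multiples of the LR size,
# so hr is trimmed to (ih * scale, iw * scale) to stay aligned with the
# model output. A hypothetical x3 case:
import numpy as np

lr = np.zeros((94, 125, 3))    # x3 LR image
hr = np.zeros((283, 376, 3))   # HR counterpart, not an exact multiple
ih, iw = lr.shape[0:2]
hr = hr[0:ih * 3, 0:iw * 3]
print(hr.shape)                # (282, 375, 3): exactly 3x the LR size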
# EDSR baseline model (x2)
#python main.py --model EDSR --scale 2 --pre_train ../experiment/model/EDSR_baseline_x2.pt --reset --test_only --n_val 100 --save_results
#python main.py --model EDSR --scale 2 --save EDSR_baseline_x2 --reset
# EDSR baseline model (x3) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 3 --save EDSR_baseline_x3 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt
#python main.py --model EDSR --scale 3 --pre_train ../experiment/model/EDSR_baseline_x3.pt --reset --test_only --n_val 100
# EDSR baseline model (x4) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 4 --save EDSR_baseline_x4 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt
#python main.py --model EDSR --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --reset --test_only --n_val 100
# EDSR in the paper (x2)
#python main.py --model EDSR --scale 2 --save EDSR_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
# EDSR in the paper (x3) - from EDSR (x2)
#python main.py --model EDSR --scale 3 --save EDSR_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train ../experiment/EDSR_x2/model/model_best.pt
# EDSR in the paper (x4) - from EDSR (x2)
#python main.py --model EDSR --scale 4 --save EDSR_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train ../experiment/EDSR_x2/model/model_best.pt
# MDSR baseline model
#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
python main.py --template MDSR --model MDSR --scale 2+3+4 --reset --test_only --n_val 10 --pre_train ../experiment/model/MDSR_baseline_jpeg_modified.pt
# MDSR in the paper
#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
# Standard benchmarks (Ex. EDSR_baseline_x4)
#python main.py --data_test Set5 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test Set14 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test B100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test Urban100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test DIV2K --ext img --n_val 100 --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --self_ensemble
#python main.py --data_test Set5 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
#python main.py --data_test Set14 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
#python main.py --data_test B100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
#python main.py --data_test Urban100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
#python main.py --data_test DIV2K --ext img --n_val 100 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../experiment/model/EDSR_x4.pt --test_only --self_ensemble
# Test your own images
python main.py --data_test Demo --scale 4 --pre_train ../experiment/model/EDSR_baseline_x4.pt --test_only --save_results
# !!!Currently disabled!!!
# Advanced - JPEG artifact removal
#python main.py --template MDSR_jpeg --model MDSR --scale 2+3+4 --save MDSR_jpeg --quality 75+ --reset
# Advanced - Test with JPEG images
#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train ../experiment/model/MDSR_baseline_jpeg.pt --test_only --save_results
import torch

import utility

from option import args
from data import data
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)

if checkpoint.ok:
    my_loader = data(args).get_loader()
    # ... (remainder of main.py omitted in the commit diff) ...
import math
import random
from decimal import Decimal
from functools import reduce

import utility

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.utils as tu

class Trainer():
    def __init__(self, loader, ckp, args):
        # ... (unchanged lines omitted in the commit diff; the lines below
        # are from the training loop) ...

        self.ckp.add_log(torch.zeros(1, len(self.loss)))
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        for batch, (lr, hr, idx_scale) in enumerate(self.loader_train):
            lr, hr = self.prepare([lr, hr])
            self._scale_change(idx_scale)

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr)
            loss = self._calc_loss(sr, hr)
            # Skip batches whose loss explodes past skip_threshold
            if loss.data[0] < self.args.skip_threshold * self.error_last:
                loss.backward()
                self.optimizer.step()
            else:
                print('Skip this batch {}! (Loss: {})'.format(
                    batch + 1, loss.data[0]
                ))

            timer_model.hold()

        # ... (unchanged lines omitted in the commit diff; the lines below
        # are from the test routine) ...

        # We can use custom forward function
        def _test_forward(x, scale):
            if self.args.self_ensemble:
                return utility.x8_forward(x, self.model, self.args.precision)
            elif self.args.chop_forward:
                return utility.chop_forward(x, self.model, scale)
            else:
                return self.model(x)

        timer_test = utility.timer()
        set_name = type(self.loader_test.dataset).__name__
        for idx_scale, scale in enumerate(self.scale):
            eval_acc = 0
            self._scale_change(idx_scale, self.loader_test)
            for idx_img, (lr, hr, _) in enumerate(self.loader_test):
                # The Demo loader yields a filename in place of a HR image,
                # so there is no ground truth to evaluate against
                no_eval = isinstance(hr[0], torch._six.string_classes)
                if no_eval:
                    lr = self.prepare([lr], volatile=True)[0]
                    filename = hr[0]
                else:
                    lr, hr = self.prepare([lr, hr], volatile=True)
                    filename = idx_img + 1

                sr = _test_forward(lr, scale)
                sr = utility.quantize(sr, self.args.rgb_range)

                if no_eval:
                    save_list = [sr]
                else:
                    eval_acc += utility.calc_PSNR(
                        sr,
                        hr.div(self.args.rgb_range),
                        set_name,
                        scale
                    )
                    save_list = [sr, lr, hr]

                if self.args.save_results:
                    self.ckp.save_results(filename, save_list, scale)

            self.ckp.log_test[-1, idx_scale] = eval_acc / len(self.loader_test)
            best = self.ckp.log_test.max(0)
            performance = 'PSNR: {:.3f}'.format(
                self.ckp.log_test[-1, idx_scale]
            )
            self.ckp.write_log(
                '[{} x{}]\t{} (Best: {:.3f} from epoch {})'.format(
                    set_name,
                    scale,
                    performance,
                    best[0][idx_scale],
                    best[1][idx_scale] + 1
                )
            )

        is_best = (best[1][0] + 1 == epoch)
        self.ckp.write_log(
            'Time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )
        self.ckp.save(self, epoch, is_best=is_best)

    def prepare(self, l, volatile=False):
        def _prepare(idx, tensor):
            if not self.args.no_cuda:
                tensor = tensor.cuda()
            if self.args.precision == 'half':
                tensor = tensor.half()

            # Only the test LR tensor can be volatile
            return Variable(tensor, volatile=(volatile and idx == 0))

        return [_prepare(i, _l) for i, _l in enumerate(l)]

    def _calc_loss(self, sr, hr):
        loss_list = []

        for i, l in enumerate(self.loss):
            # A model may return a list of outputs; pair each output with
            # its loss term
            if isinstance(sr, list):
                if isinstance(hr, list):
                    loss = l['function'](sr[i], hr[i])
                else:
                    loss = l['function'](sr[i], hr)
            else:
                loss = l['function'](sr, hr)

            loss_list.append(l['weight'] * loss)
            self.ckp.log_training[-1, i] += loss.data[0]

        # ... (unchanged lines omitted in the commit diff) ...

        else:
            epoch = self.scheduler.last_epoch + 1
            return epoch >= self.args.epochs
class checkpoint():
    # ... (unchanged lines omitted in the commit diff) ...

            optimizer_function = optim.Adam
            kwargs = {
                'betas': (self.args.beta1, self.args.beta2),
                'eps': self.args.epsilon
            }
        elif self.args.optimizer == 'RMSprop':
            optimizer_function = optim.RMSprop
            kwargs = {'eps': self.args.epsilon}
    # ... (unchanged lines omitted in the commit diff) ...

        else:
            state = trainer.model.state_dict()

        # Collect (object, relative path) pairs first, then write them once
        save_list = [(state, 'model/model_latest.pt')]
        if not self.args.test_only:
            if is_best:
                save_list.append((state, 'model/model_best.pt'))
            if self.args.save_models:
                save_list.append((state, 'model/model_{}.pt'.format(epoch)))

            save_list.append((trainer.loss, 'loss.pt'))
            save_list.append((trainer.optimizer.state_dict(), 'optimizer.pt'))
            save_list.append((self.log_training, 'log_training.pt'))
            save_list.append((self.log_test, 'log_test.pt'))
            self.plot(trainer, epoch, self.log_training, self.log_test)

        for o, p in save_list:
            torch.save(o, os.path.join(self.dir, p))

    def write_log(self, log, refresh=False):
        print(log)
        self.log_file.write(log + '\n')
    # ... (unchanged lines omitted in the commit diff) ...
            fig,
            '{}/test_{}.pdf'.format(self.dir, set_name))

    def save_results(self, filename, save_list, scale):
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
        postfix = ('SR', 'LR', 'HR')
        for v, p in zip(save_list, postfix):
            tu.save_image(v.data[0], '{}{}.png'.format(filename, p), padding=0)
def chop_forward(x, model, scale, shave=10, min_size=80000, n_GPUs=1):
    # Memory-efficient forward: split the input into four overlapping
    # quadrants, recurse until each piece is small enough for the model,
    # then stitch the super-resolved quadrants back together
    n_GPUs = min(n_GPUs, 4)
    b, c, h, w = x.size()
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    lr_list = [
        x[:, :, 0:h_size, 0:w_size],
        x[:, :, 0:h_size, (w - w_size):w],
        x[:, :, (h - h_size):h, 0:w_size],
        x[:, :, (h - h_size):h, (w - w_size):w]]

    if w_size * h_size < min_size:
        sr_list = []
        for i in range(0, 4, n_GPUs):
            lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
            sr_batch = model(lr_batch)
            sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
    else:
        sr_list = [
            chop_forward(patch, model, scale, shave, min_size, n_GPUs) \
            for patch in lr_list
        ]

    h, w = scale * h, scale * w
    h_half, w_half = scale * h_half, scale * w_half
    # ... (h_size/w_size scaling omitted in the commit diff) ...

    output = Variable(x.data.new(b, c, h, w), volatile=True)
    output[:, :, 0:h_half, 0:w_half] \
        = sr_list[0][:, :, 0:h_half, 0:w_half]
    output[:, :, 0:h_half, w_half:w] \
        = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
    output[:, :, h_half:h, 0:w_half] \
        = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
    output[:, :, h_half:h, w_half:w] \
        = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

    return output
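# --- Example (not part of the commit): exercising chop_forward --------------
# Hypothetical driver, assuming the elided lines above scale h_size/w_size
# by `scale` as in the original source. A nearest-neighbor upsampler stands
# in for a real x4 SR network (PyTorch 0.3-style API, as used in this repo).
import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.UpsamplingNearest2d(scale_factor=4)   # stand-in "SR model"
x = Variable(torch.randn(1, 3, 400, 400), volatile=True)
sr = chop_forward(x, model, scale=4)
print(sr.size())                                 # (1, 3, 1600, 1600)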
def x8_forward(img, model, precision='single'):
    # ... (the _transform helper is omitted in the commit diff) ...
        return Variable(ret, volatile=True)

    # Generate the eight dihedral variants (flips and transpose), run the
    # model on each, undo the transforms, and average the results
    lr_list = [img]
    for tf in 'vflip', 'hflip', 'transpose':
        lr_list.extend([_transform(t, tf) for t in lr_list])

    sr_list = [model(aug) for aug in lr_list]
    for i in range(len(sr_list)):
        if i > 3:
            sr_list[i] = _transform(sr_list[i], 'transpose')
        if i % 4 > 1:
            sr_list[i] = _transform(sr_list[i], 'hflip')
        if (i % 4) % 2 == 1:
            sr_list[i] = _transform(sr_list[i], 'vflip')

    output_cat = torch.cat(sr_list, dim=0)
    output = output_cat.mean(dim=0, keepdim=True)
    #output = output_cat.median(dim=0, keepdim=True)[0]
    # ... (unchanged lines omitted in the commit diff) ...

def quantize(img, rgb_range):
    return img.mul(255 / rgb_range).clamp(0, 255).round().div(255)

def calc_PSNR(sr, hr, set_name, scale):
    '''
        Here we assume normalized (0-1) and quantized arguments.
        For the Set5, Set14, B100, and Urban100 datasets,
        we measure PSNR on the luminance channel only.
    '''
    diff = (sr - hr).data
    _, c, h, w = diff.size()

    luminance_only = ['Set5', 'Set14', 'B100', 'Urban100']
    if set_name in luminance_only:
        shave = scale
        if c > 1:
            # RGB -> Y conversion with ITU-R BT.601 coefficients
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    valid = diff[:, :, shave:(h - shave), shave:(w - shave)]
    mse = valid.pow(2).mean()

    return -10 * math.log10(mse)
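# --- Example (not part of the commit): sanity-checking calc_PSNR ------------
# Hypothetical check using the 0-1 convention assumed above (PyTorch
# 0.3-style Variables). Identical images give infinite PSNR, so perturb
# one input slightly; a 0.01 offset should land near 40 dB.
import torch
from torch.autograd import Variable

hr = Variable(torch.rand(1, 3, 64, 64))
sr = (hr + 0.01).clamp(0, 1)
print(calc_PSNR(sr, hr, set_name='Demo', scale=4))   # roughly 40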