Commit bf5034d5 authored by Sanghyun Son

support PyTorch 1.2.0

    Check the update logs
parent 9a9d7d74
Merge request !1: Jan 09, 2018 updates
**About PyTorch 1.2.0**
* Now the master branch supports PyTorch 1.2.0 by default.
* Due to a serious version problem (especially with ``torch.utils.data.dataloader``), MDSR functions are temporarily disabled. If you have to train or evaluate the MDSR model, please use the legacy branches.
# EDSR-PyTorch

![](/figs/main.png)
@@ -20,7 +20,7 @@ If you find our work useful in your research or publication, please cite our work
    year = {2017}
}
```
We provide scripts for reproducing all the results from our paper. You can train your model from scratch, or use a pre-trained model to enlarge your images.

**Differences from the Torch version**
* The code is much more compact. (Removed all unnecessary parts.)
@@ -47,7 +47,7 @@ cd EDSR-PyTorch
```

## Quickstart (Demo)
You can test our super-resolution algorithm with your images. Place your images in the ``test`` folder (like ``test/<your_image>``). We support **png** and **jpeg** files.

Run the script in the ``src`` folder. Before you run the demo, uncomment the appropriate line in ``demo.sh`` for the model you want to execute.
```bash
@@ -123,17 +123,17 @@ sh demo.sh
  * Basically, this function first splits a large image into small patches. Those images are merged after super-resolution. I checked this function with 12GB memory and a 4000 x 2000 input image at scale 4. (Therefore, the output will be 16000 x 8000.) A rough sketch of this patch-based forward is given after this update log.
* Feb 21, 2018
  * Fixed the problem when loading a pre-trained multi-GPU model.
  * Added a pre-trained scale 2 baseline model.
  * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use the ``--save_models`` argument to keep all the intermediate models.
  * PyTorch 0.3.1 changed their implementation of the DataLoader function. Therefore, I also changed my implementation of MSDataLoader. You can find it on the feature/dataloader branch.
* Feb 23, 2018
  * Now PyTorch 0.3.1 is the default. Use the legacy/0.3.0 branch if you use the old version.
  * With the new ``src/data/DIV2K.py`` code, one can easily create a new data class for super-resolution.
  * New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have it.)
    * With ``--ext bin``, this code will automatically generate and save the binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires huge RAM (~45GB; swap can be used), so please be careful.)
    * If you cannot make the binary pack, use the default setting (``--ext img``).
  * Fixed a bug where the PSNR in the log and the PSNR calculated from the saved images did not match.
    * Now saved images have better quality! (The PSNR is ~0.1dB higher than with the original code.)
@@ -146,23 +146,23 @@ sh demo.sh
* Mar 11, 2018
  * Fixed some typos in the code and script.
  * Now ``--ext img`` is the default setting. Although we recommend using ``--ext bin`` for training, please use ``--ext img`` when you use ``--test_only``.
  * The skip_batch operation is implemented. Use the ``--skip_threshold`` argument to skip batches that you want to ignore. Although this function is not exactly the same as that of the Torch7 version, it will work as you expect.
* Mar 20, 2018
  * Use ``--ext sep-reset`` to pre-decode large png files. Those decoded files will be saved to the same directory as the DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
  * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
  * Changed the behavior of skip_batch.
* Mar 29, 2018
  * We now provide all models from our paper.
  * We also provide the ``MDSR_baseline_jpeg`` model that suppresses JPEG artifacts in the original low-resolution image. Please use it if you have any trouble.
  * The ``MyImage`` dataset is changed to the ``Demo`` dataset. Also, it works more efficiently than before.
  * Some codes and scripts are re-written.
* Apr 9, 2018
  * VGG and adversarial losses are implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet.
  * Many parts of the code are refactored. If you find a bug, please report it.
  * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. The default setting is D-DBPN-L.
* Apr 26, 2018
  * Compatible with PyTorch 0.4.0.
@@ -171,9 +171,12 @@ sh demo.sh
* July 22, 2018
  * Thanks for the recent commits that contain RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
  * Now the dataloader is much more stable than the previous version. Please erase the ``DIV2K/bin`` folder that was created before this commit. Also, please avoid using the ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but SLOW!).
* Oct 18, 2018
  * With ``--pre_train download``, pretrained models will be automatically downloaded from the server.
  * Supports video input/output (inference only). Try it with ``--data_test video --dir_demo [video file directory]``.
* About PyTorch 1.0.0
  * We support PyTorch 1.0.0. If you prefer the previous versions of PyTorch, use the legacy branches.
  * ``--ext bin`` is not supported. Also, please erase your bin files with ``--ext sep-reset``. Once you successfully build those bin files, you can remove ``-reset`` from the argument.
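
The patch-based ``--chop`` forward mentioned in the update log can be summarized with a short sketch. This is only an illustration, not the code shipped in this repository: it assumes a generic PyTorch super-resolution ``model`` and uses non-overlapping patches, whereas the actual implementation also handles overlapping borders and memory limits.

```python
import torch

def forward_chop(model, x, scale, patch_size=400):
    # x: LR tensor of shape (1, C, H, W); the output has shape (1, C, H*scale, W*scale).
    _, c, h, w = x.size()
    sr = x.new_zeros(1, c, h * scale, w * scale)
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            lr_patch = x[:, :, top:top + patch_size, left:left + patch_size]
            with torch.no_grad():
                sr_patch = model(lr_patch)
            # Paste the super-resolved patch into the matching output region.
            sr[:, :,
               top * scale:(top + lr_patch.size(2)) * scale,
               left * scale:(left + lr_patch.size(3)) * scale] = sr_patch
    return sr
```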
from importlib import import_module
#from dataloader import MSDataLoader
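# NOTE: MSDataLoader (dataloader.py) builds on private PyTorch dataloader internals
# that changed in 1.2.0, so the standard DataLoader below is used instead.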
from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset

# This is a simple wrapper function for ConcatDataset
@@ -22,12 +23,12 @@ class Data:
        m = import_module('data.' + module_name.lower())
        datasets.append(getattr(m, module_name)(args, name=d))
        self.loader_train = dataloader.DataLoader(
            MyConcatDataset(datasets),
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=not args.cpu,
            num_workers=args.n_threads,
        )
        self.loader_test = []
@@ -40,11 +41,12 @@ class Data:
        m = import_module('data.' + module_name.lower())
        testset = getattr(m, module_name)(args, train=False, name=d)
        self.loader_test.append(
            dataloader.DataLoader(
                testset,
                batch_size=1,
                shuffle=False,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )
        )
import sys
import threading
import queue
import random
import collections
import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
    _remove_worker_pids, _error_if_any_worker_fails
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data.dataloader import ManagerWatchdog
from torch.utils.data.dataloader import _pin_memory_loop
from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import ExceptionWrapper
from torch.utils.data.dataloader import _use_shared_memory
from torch.utils.data.dataloader import numpy_type_map
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import pin_memory_batch
from torch.utils.data.dataloader import _SIGCHLD_handler_set
from torch.utils.data.dataloader import _set_SIGCHLD_handler
if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue
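
# Worker-process loop: pulls batch indices from index_queue, picks a random scale
# index when training with multiple scales, collates the samples, and puts the
# batch (followed by the chosen scale index) on data_queue.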
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    try:
        global _use_shared_memory
        _use_shared_memory = True
        _set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                assert done_event.is_set()
                return
            elif done_event.is_set():
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))

    except KeyboardInterrupt:
        pass
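
# Iterator that mirrors torch.utils.data.dataloader._DataLoaderIter, but spawns
# workers running _ms_loop above so that every batch also carries its scale index.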
class _MSDataLoaderIter(_DataLoaderIter):

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            base_seed = torch.LongTensor(1).random_()[0]

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            for _ in range(2 * self.num_workers):
                self._put_indices()
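
# DataLoader subclass that reads the number of workers and the scale list from args
# and iterates with _MSDataLoaderIter.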
class MSDataLoader(DataLoader):

    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=default_collate, pin_memory=False, drop_last=False,
        timeout=0, worker_init_fn=None):

        super(MSDataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=args.n_threads,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn
        )

        self.scale = args.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
# EDSR baseline model (x2) + JPEG augmentation
python main.py --model EDSR --scale 2 --patch_size 96 --save test_edsr_baseline_x2 --reset
#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75

# EDSR baseline model (x3) - from EDSR baseline model (x2)
...
@@ -26,7 +26,6 @@ class Trainer():
        self.error_last = 1e8

    def train(self):
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        lr = self.optimizer.get_lr()
@@ -38,13 +37,15 @@ class Trainer():
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        # TEMP
        self.loader_train.dataset.set_scale(0)
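        # The scale index is fixed to 0 here because the multi-scale MSDataLoader is
        # disabled under PyTorch 1.2.0 (see the README note above).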
        for batch, (lr, hr, _,) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr, 0)
            loss = self.loss(sr, hr)
            loss.backward()
            if self.args.gclip > 0:
@@ -68,11 +69,12 @@ class Trainer():
        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()
    def test(self):
        torch.set_grad_enabled(False)

        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
@@ -84,7 +86,7 @@ class Trainer():
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    sr = self.model(lr, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)
...