diff --git a/README.md b/README.md
index c93e0ee4e82a9a7dda03b67090756c7bde14e8ce..c567af77ddd2dc01e2d20c5998e4a788895f328e 100755
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-**About PyTorch 1.0.0**
-  * We support PyTorch 1.0.0. If you prefer the previous versions of PyTorch, use legacy branches.
-  * ``--ext bin`` is not supported. Also, please erase your bin files with ``--ext sep-reset``. Once you successfully build those bin files, you can remove ``-reset`` from the argument.
+**About PyTorch 1.2.0**
+  * The master branch now supports PyTorch 1.2.0 by default.
+  * Due to serious version incompatibilities (especially in torch.utils.data.dataloader), the MDSR functions are temporarily disabled. If you have to train/evaluate the MDSR model, please use the legacy branches.
 
 # EDSR-PyTorch
 
@@ -20,7 +20,7 @@ If you find our work useful in your research or publication, please cite our wor
 year = {2017}
 }
 ```
-We provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use pre-trained model to enlarge your images.
+We provide scripts for reproducing all the results from our paper. You can train your model from scratch, or use a pre-trained model to enlarge your images.
 
 **Differences between Torch version**
 * Codes are much more compact. (Removed all unnecessary parts.)
@@ -46,8 +46,8 @@
 git clone https://github.com/thstkdgus35/EDSR-PyTorch
 cd EDSR-PyTorch
 ```
-## Quick start (Demo)
-You can test our super-resolution algorithm with your own images. Place your images in ``test`` folder. (like ``test/<your_image>``) We support **png** and **jpeg** files.
+## Quickstart (Demo)
+You can test our super-resolution algorithm with your images. Place your images in the ``test`` folder (like ``test/<your_image>``). We support **png** and **jpeg** files.
 
 Run the script in ``src`` folder. Before you run the demo, please uncomment the appropriate line in ```demo.sh``` that you want to execute.
 ```bash
@@ -123,17 +123,17 @@ sh demo.sh
   * Basically, this function first split a large image to small patches. Those images are merged after super-resolution. I checked this function with 12GB memory, 4000 x 2000 input image in scale 4. (Therefore, the output will be 16000 x 8000.)
 * Feb 21, 2018
-  * Fixed the problem when loading pre-trained multi-gpu model.
+  * Fixed a problem when loading a pre-trained multi-GPU model.
   * Added pre-trained scale 2 baseline model.
-  * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use --save_models argument to save all the intermediate models.
+  * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use the --save_models argument to keep all the intermediate models.
   * PyTorch 0.3.1 changed their implementation of DataLoader function. Therefore, I also changed my implementation of MSDataLoader. You can find it on feature/dataloader branch.
 * Feb 23, 2018
-  * Now PyTorch 0.3.1 is default. Use legacy/0.3.0 branch if you use the old version.
+  * Now PyTorch 0.3.1 is the default. Use the legacy/0.3.0 branch if you use the old version.
   * With a new ``src/data/DIV2K.py`` code, one can easily create new data class for super-resolution.
   * New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have.)
-  * With ``--ext bin``, this code will automatically generates and saves the binary data pack that corresponds to previous ``DIV2K_decoded``. (This requires huge RAM (~45GB, Swap can be used.), so please be careful.)
-  * If you cannot make the binary pack, just use the default setting (``--ext img``).
+  * With ``--ext bin``, this code will automatically generate and save the binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires a huge amount of RAM (~45GB; swap can be used), so please be careful.)
+  * If you cannot make the binary pack, use the default setting (``--ext img``).
   * Fixed a bug that PSNR in the log and PSNR calculated from the saved images does not match.
   * Now saved images have better quality! (PSNR is ~0.1dB higher than the original code.)
@@ -146,23 +146,23 @@ sh demo.sh
 * Mar 11, 2018
   * Fixed some typos in the code and script.
   * Now --ext img is default setting. Although we recommend you to use --ext bin when training, please use --ext img when you use --test_only.
-  * Skip_batch operation is implemented. Use --skip_threshold argument to skip the batch that you want to ignore. Although this function is not exactly same with that of Torch7 version, it will work as you expected.
+  * Skip_batch operation is implemented. Use the --skip_threshold argument to skip batches that you want to ignore. Although this function is not exactly the same as that of the Torch7 version, it works as you would expect.
 * Mar 20, 2018
-  * Use ``--ext sep_reset`` to pre-decode large png files. Those decoded files will be saved to the same directory with DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
+  * Use ``--ext sep-reset`` to pre-decode large png files. Those decoded files will be saved to the same directory as the DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
   * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
   * Changed the behavior of skip_batch.
 * Mar 29, 2018
   * We now provide all models from our paper.
-  * We also provide ``MDSR_baseline_jpeg`` model that suppresses JPEG artifacts in original low-resolution image. Please use it if you have any trouble.
+  * We also provide the ``MDSR_baseline_jpeg`` model, which suppresses JPEG artifacts in the original low-resolution image. Please use it if you have any trouble.
   * ``MyImage`` dataset is changed to ``Demo`` dataset. Also, it works more efficient than before.
   * Some codes and script are re-written.
 * Apr 9, 2018
   * VGG and Adversarial loss is implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet.
   * Many codes are refactored. If there exists a bug, please report it.
-  * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. Default setting is D-DBPN-L.
+  * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. The default setting is D-DBPN-L.
 * Apr 26, 2018
   * Compatible with PyTorch 0.4.0
@@ -171,9 +171,12 @@ sh demo.sh
 * July 22, 2018
   * Thanks for recent commits that contains RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
-  * Now the dataloader is much stable than the previous version. Please erase ``DIV2K/bin`` folder that is created before this commit. Also, please avoid to use ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough spaces(~10GB) in your disk, we recommend ``--ext img``(But SLOW!).
+  * Now the dataloader is much more stable than the previous version. Please erase the ``DIV2K/bin`` folder that was created before this commit. Also, please avoid using the ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but it is SLOW!).
 * Oct 18, 2018
-  * with ``--pre_train download``, pretrained models will be automatically downloaded from server.
+  * With ``--pre_train download``, pre-trained models will be automatically downloaded from the server.
   * Supports video input/output (inference only). Try with ``--data_test video --dir_demo [video file directory]``.
+* About PyTorch 1.0.0
+  * We support PyTorch 1.0.0. If you prefer the previous versions of PyTorch, use the legacy branches.
+  * ``--ext bin`` is not supported. Also, please erase your bin files with ``--ext sep-reset``. Once you have successfully built those bin files, you can remove ``-reset`` from the argument.
diff --git a/src/data/__init__.py b/src/data/__init__.py
index 827320054c221b7f7e4b25baec2e2a57259807f0..6dc2cdb4a1652976634a657c9d4106cb58a233be 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -1,5 +1,6 @@
 from importlib import import_module
-from dataloader import MSDataLoader
+#from dataloader import MSDataLoader
+from torch.utils.data import dataloader
 from torch.utils.data import ConcatDataset
 
 # This is a simple wrapper function for ConcatDataset
@@ -22,12 +23,12 @@ class Data:
                 m = import_module('data.' + module_name.lower())
                 datasets.append(getattr(m, module_name)(args, name=d))
 
-            self.loader_train = MSDataLoader(
-                args,
+            self.loader_train = dataloader.DataLoader(
                 MyConcatDataset(datasets),
                 batch_size=args.batch_size,
                 shuffle=True,
-                pin_memory=not args.cpu
+                pin_memory=not args.cpu,
+                num_workers=args.n_threads,
             )
 
         self.loader_test = []
@@ -40,11 +41,12 @@ class Data:
                 m = import_module('data.' + module_name.lower())
                 testset = getattr(m, module_name)(args, train=False, name=d)
 
-            self.loader_test.append(MSDataLoader(
-                args,
-                testset,
-                batch_size=1,
-                shuffle=False,
-                pin_memory=not args.cpu
-            ))
-
+            self.loader_test.append(
+                dataloader.DataLoader(
+                    testset,
+                    batch_size=1,
+                    shuffle=False,
+                    pin_memory=not args.cpu,
+                    num_workers=args.n_threads,
+                )
+            )
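Since `src/data/__init__.py` now builds plain `torch.utils.data.DataLoader` objects, batches no longer carry the trailing `idx_scale` element that `MSDataLoader` used to append. Below is a minimal, standalone sketch of the new unpacking; the `ToyPairs` dataset is hypothetical and only stands in for the repository's data classes, which return `(lr, hr, filename)` tuples:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyPairs(Dataset):
    """Stand-in for a data class that returns (lr, hr, filename) tuples."""
    def __len__(self):
        return 4

    def __getitem__(self, i):
        return torch.zeros(3, 24, 24), torch.zeros(3, 48, 48), 'img_%d' % i

loader = DataLoader(ToyPairs(), batch_size=2, shuffle=True, num_workers=0)
for lr, hr, filename in loader:  # no trailing idx_scale element any more
    print(lr.shape, hr.shape, filename)
```

This three-element tuple is exactly what the updated `trainer.py` below unpacks.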
diff --git a/src/dataloader.py b/src/dataloader.py
deleted file mode 100644
index 16f4e2dc8442656c0daf0316872c051665700bdc..0000000000000000000000000000000000000000
--- a/src/dataloader.py
+++ /dev/null
@@ -1,173 +0,0 @@
-import sys
-import threading
-import queue
-import random
-import collections
-
-import torch
-import torch.multiprocessing as multiprocessing
-
-from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
-    _remove_worker_pids, _error_if_any_worker_fails
-from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataloader import _DataLoaderIter
-from torch.utils.data.dataloader import ManagerWatchdog
-from torch.utils.data.dataloader import _pin_memory_loop
-from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL
-
-from torch.utils.data.dataloader import ExceptionWrapper
-from torch.utils.data.dataloader import _use_shared_memory
-from torch.utils.data.dataloader import numpy_type_map
-from torch.utils.data.dataloader import default_collate
-from torch.utils.data.dataloader import pin_memory_batch
-from torch.utils.data.dataloader import _SIGCHLD_handler_set
-from torch.utils.data.dataloader import _set_SIGCHLD_handler
-
-if sys.version_info[0] == 2:
-    import Queue as queue
-else:
-    import queue
-
-def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
-    try:
-        global _use_shared_memory
-        _use_shared_memory = True
-        _set_worker_signal_handlers()
-
-        torch.set_num_threads(1)
-        random.seed(seed)
-        torch.manual_seed(seed)
-        data_queue.cancel_join_thread()
-
-        if init_fn is not None:
-            init_fn(worker_id)
-
-        watchdog = ManagerWatchdog()
-
-        while watchdog.is_alive():
-            try:
-                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
-            except queue.Empty:
-                continue
-
-            if r is None:
-                assert done_event.is_set()
-                return
-            elif done_event.is_set():
-                continue
-
-            idx, batch_indices = r
-            try:
-                idx_scale = 0
-                if len(scale) > 1 and dataset.train:
-                    idx_scale = random.randrange(0, len(scale))
-                    dataset.set_scale(idx_scale)
-
-                samples = collate_fn([dataset[i] for i in batch_indices])
-                samples.append(idx_scale)
-            except Exception:
-                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
-            else:
-                data_queue.put((idx, samples))
-    except KeyboardInterrupt:
-        pass
-
-class _MSDataLoaderIter(_DataLoaderIter):
-    def __init__(self, loader):
-        self.dataset = loader.dataset
-        self.scale = loader.scale
-        self.collate_fn = loader.collate_fn
-        self.batch_sampler = loader.batch_sampler
-        self.num_workers = loader.num_workers
-        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
-        self.timeout = loader.timeout
-
-        self.sample_iter = iter(self.batch_sampler)
-
-        base_seed = torch.LongTensor(1).random_().item()
-
-        if self.num_workers > 0:
-            self.worker_init_fn = loader.worker_init_fn
-            self.worker_queue_idx = 0
-            self.worker_result_queue = multiprocessing.Queue()
-            self.batches_outstanding = 0
-            self.worker_pids_set = False
-            self.shutdown = False
-            self.send_idx = 0
-            self.rcvd_idx = 0
-            self.reorder_dict = {}
-            self.done_event = multiprocessing.Event()
-
-            base_seed = torch.LongTensor(1).random_()[0]
-
-            self.index_queues = []
-            self.workers = []
-            for i in range(self.num_workers):
-                index_queue = multiprocessing.Queue()
-                index_queue.cancel_join_thread()
-                w = multiprocessing.Process(
-                    target=_ms_loop,
-                    args=(
-                        self.dataset,
-                        index_queue,
-                        self.worker_result_queue,
-                        self.done_event,
-                        self.collate_fn,
-                        self.scale,
-                        base_seed + i,
-                        self.worker_init_fn,
-                        i
-                    )
-                )
-                w.start()
-                self.index_queues.append(index_queue)
-                self.workers.append(w)
-
-            if self.pin_memory:
-                self.data_queue = queue.Queue()
-                pin_memory_thread = threading.Thread(
-                    target=_pin_memory_loop,
-                    args=(
-                        self.worker_result_queue,
-                        self.data_queue,
-                        torch.cuda.current_device(),
-                        self.done_event
-                    )
-                )
-                pin_memory_thread.daemon = True
-                pin_memory_thread.start()
-                self.pin_memory_thread = pin_memory_thread
-            else:
-                self.data_queue = self.worker_result_queue
-
-            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
-            _set_SIGCHLD_handler()
-            self.worker_pids_set = True
-
-            for _ in range(2 * self.num_workers):
-                self._put_indices()
-
-class MSDataLoader(DataLoader):
-    def __init__(
-        self, args, dataset, batch_size=1, shuffle=False,
-        sampler=None, batch_sampler=None,
-        collate_fn=default_collate, pin_memory=False, drop_last=False,
-        timeout=0, worker_init_fn=None):
-
-        super(MSDataLoader, self).__init__(
-            dataset,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            sampler=sampler,
-            batch_sampler=batch_sampler,
-            num_workers=args.n_threads,
-            collate_fn=collate_fn,
-            pin_memory=pin_memory,
-            drop_last=drop_last,
-            timeout=timeout,
-            worker_init_fn=worker_init_fn
-        )
-
-        self.scale = args.scale
-
-    def __iter__(self):
-        return _MSDataLoaderIter(self)
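If you still need the per-batch random scale that the deleted `_ms_loop` provided, its core behavior can be approximated in a single process without any custom loader. The sketch below is a hypothetical helper (not part of the codebase); it assumes the dataset exposes `set_scale()` and a `train` flag, as this repository's data classes do:

```python
import random

import torch
from torch.utils.data.dataloader import default_collate

def iterate_multi_scale(dataset, scales, batch_size):
    """Yield (batch, idx_scale), picking a random scale per batch as _ms_loop did."""
    order = torch.randperm(len(dataset)).tolist()
    for begin in range(0, len(order), batch_size):
        idx_scale = 0
        if len(scales) > 1 and dataset.train:
            idx_scale = random.randrange(0, len(scales))
            dataset.set_scale(idx_scale)  # must happen before the items are loaded
        indices = order[begin:begin + batch_size]
        batch = default_collate([dataset[i] for i in indices])
        yield batch, idx_scale
```

Unlike the multiprocessing version, this loads items in the main process, so `set_scale()` reliably takes effect before each batch is assembled.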
diff --git a/src/trainer.py b/src/trainer.py
index 40b15a0f6dcfd1f35dd053eefe7a5a890bedd3af..849ae5c7e47aa0dd18e369c64774edc868d06e2d 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -26,7 +26,6 @@ class Trainer():
         self.error_last = 1e8
 
     def train(self):
-        self.optimizer.schedule()
         self.loss.step()
         epoch = self.optimizer.get_last_epoch() + 1
         lr = self.optimizer.get_lr()
@@ -38,13 +37,15 @@
         self.model.train()
 
         timer_data, timer_model = utility.timer(), utility.timer()
-        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
+        # TEMP: fix the scale index to 0 while MDSR (multi-scale) is disabled
+        self.loader_train.dataset.set_scale(0)
+        for batch, (lr, hr, _) in enumerate(self.loader_train):
             lr, hr = self.prepare(lr, hr)
             timer_data.hold()
             timer_model.tic()
 
             self.optimizer.zero_grad()
-            sr = self.model(lr, idx_scale)
+            sr = self.model(lr, 0)
             loss = self.loss(sr, hr)
             loss.backward()
             if self.args.gclip > 0:
@@ -68,11 +69,12 @@
 
         self.loss.end_log(len(self.loader_train))
         self.error_last = self.loss.log[-1, -1]
+        self.optimizer.schedule()
 
     def test(self):
         torch.set_grad_enabled(False)
 
-        epoch = self.optimizer.get_last_epoch() + 1
+        epoch = self.optimizer.get_last_epoch()
         self.ckp.write_log('\nEvaluation:')
         self.ckp.add_log(
             torch.zeros(1, len(self.loader_test), len(self.scale))
@@ -84,7 +86,7 @@ class Trainer():
         for idx_data, d in enumerate(self.loader_test):
             for idx_scale, scale in enumerate(self.scale):
                 d.dataset.set_scale(idx_scale)
-                for lr, hr, filename, _ in tqdm(d, ncols=80):
+                for lr, hr, filename in tqdm(d, ncols=80):
                     lr, hr = self.prepare(lr, hr)
                     sr = self.model(lr, idx_scale)
                     sr = utility.quantize(sr, self.args.rgb_range)
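The `trainer.py` hunks move `self.optimizer.schedule()` from the top of `train()` to the bottom. This matches the calling order PyTorch has required since 1.1, where `lr_scheduler.step()` must come after the epoch's `optimizer.step()` calls. A minimal sketch of that ordering with toy objects, assuming `schedule()` in `src/utility.py` wraps `lr_scheduler.step()`:

```python
import torch

# Toy model/optimizer/scheduler standing in for the repository's objects.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

for epoch in range(2):
    for _ in range(4):  # mini-batches
        optimizer.zero_grad()
        loss = model(torch.randn(2, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
    # Since PyTorch 1.1, the scheduler steps *after* the optimizer steps
    # of the epoch, which is why schedule() moved to the end of train().
    scheduler.step()
```

Because `schedule()` now runs at the end of `train()`, `get_last_epoch()` already reflects the finished epoch by the time `test()` runs, which is why the `+ 1` was dropped there.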