diff --git a/README.md b/README.md index fb94e23eeba2864fa0ce53d72e9e5a1fd2ee142c..09aacc50a6af05b3354423220587d4c4f2d4f661 100755 --- a/README.md +++ b/README.md @@ -1,13 +1,7 @@ -# [Challenges for NTIRE 2019 is open!](http://www.vision.ee.ethz.ch/ntire19/) - -**The challenge winners will be awarded at the CVPR 2019 Workshop.** - - # EDSR-PyTorch -**About PyTorch 1.0.0** - * We support PyTorch 1.0.0. If you prefer the previous versions of PyTorch, use legacy branches. - * ``--ext bin`` is not supported. Also, please erase your bin files with ``--ext sep-reset``. Once you successfully build those bin files, you can remove ``-reset`` from the argument. +**About PyTorch 1.1.0** + * There have been minor changes with the 1.1.0 update. Now we support PyTorch 1.1.0 by default, and please use the legacy branch if you prefer older version.  diff --git a/src/dataloader.py b/src/dataloader.py index 16f4e2dc8442656c0daf0316872c051665700bdc..63257a3008530f66fa56a19383e9fe7265fd0a9d 100644 --- a/src/dataloader.py +++ b/src/dataloader.py @@ -1,42 +1,33 @@ -import sys import threading -import queue import random -import collections import torch import torch.multiprocessing as multiprocessing - -from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ - _remove_worker_pids, _error_if_any_worker_fails -from torch.utils.data.dataloader import DataLoader +from torch.utils.data import DataLoader +from torch.utils.data import SequentialSampler +from torch.utils.data import RandomSampler +from torch.utils.data import BatchSampler +from torch.utils.data import _utils from torch.utils.data.dataloader import _DataLoaderIter -from torch.utils.data.dataloader import ManagerWatchdog -from torch.utils.data.dataloader import _pin_memory_loop -from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL - -from torch.utils.data.dataloader import ExceptionWrapper -from torch.utils.data.dataloader import _use_shared_memory -from torch.utils.data.dataloader import numpy_type_map -from torch.utils.data.dataloader import default_collate -from torch.utils.data.dataloader import pin_memory_batch -from torch.utils.data.dataloader import _SIGCHLD_handler_set -from torch.utils.data.dataloader import _set_SIGCHLD_handler - -if sys.version_info[0] == 2: - import Queue as queue -else: - import queue + +from torch.utils.data._utils import collate +from torch.utils.data._utils import signal_handling +from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL +from torch.utils.data._utils import ExceptionWrapper +from torch.utils.data._utils import IS_WINDOWS +from torch.utils.data._utils.worker import ManagerWatchdog + +from torch._six import queue def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): try: - global _use_shared_memory - _use_shared_memory = True - _set_worker_signal_handlers() + collate._use_shared_memory = True + signal_handling._set_worker_signal_handlers() torch.set_num_threads(1) random.seed(seed) torch.manual_seed(seed) + data_queue.cancel_join_thread() if init_fn is not None: @@ -55,6 +46,7 @@ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, se return elif done_event.is_set(): continue + idx, batch_indices = r try: idx_scale = 0 @@ -68,10 +60,13 @@ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, se data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) else: data_queue.put((idx, samples)) + del samples + except KeyboardInterrupt: pass class _MSDataLoaderIter(_DataLoaderIter): + def __init__(self, loader): self.dataset = loader.dataset self.scale = loader.scale @@ -118,6 +113,7 @@ class _MSDataLoaderIter(_DataLoaderIter): i ) ) + w.daemon = True w.start() self.index_queues.append(index_queue) self.workers.append(w) @@ -125,7 +121,7 @@ class _MSDataLoaderIter(_DataLoaderIter): if self.pin_memory: self.data_queue = queue.Queue() pin_memory_thread = threading.Thread( - target=_pin_memory_loop, + target=_utils.pin_memory._pin_memory_loop, args=( self.worker_result_queue, self.data_queue, @@ -139,35 +135,24 @@ class _MSDataLoaderIter(_DataLoaderIter): else: self.data_queue = self.worker_result_queue - _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) - _set_SIGCHLD_handler() + _utils.signal_handling._set_worker_pids( + id(self), tuple(w.pid for w in self.workers) + ) + _utils.signal_handling._set_SIGCHLD_handler() self.worker_pids_set = True for _ in range(2 * self.num_workers): self._put_indices() + class MSDataLoader(DataLoader): - def __init__( - self, args, dataset, batch_size=1, shuffle=False, - sampler=None, batch_sampler=None, - collate_fn=default_collate, pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None): + def __init__(self, cfg, *args, **kwargs): super(MSDataLoader, self).__init__( - dataset, - batch_size=batch_size, - shuffle=shuffle, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=args.n_threads, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=drop_last, - timeout=timeout, - worker_init_fn=worker_init_fn + *args, **kwargs, num_workers=cfg.n_threads ) - - self.scale = args.scale + self.scale = cfg.scale def __iter__(self): return _MSDataLoaderIter(self) + diff --git a/src/trainer.py b/src/trainer.py index 40b15a0f6dcfd1f35dd053eefe7a5a890bedd3af..9d3a71c65d5dd3056a36e8522458d736a1186e56 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -26,7 +26,6 @@ class Trainer(): self.error_last = 1e8 def train(self): - self.optimizer.schedule() self.loss.step() epoch = self.optimizer.get_last_epoch() + 1 lr = self.optimizer.get_lr() @@ -68,11 +67,12 @@ class Trainer(): self.loss.end_log(len(self.loader_train)) self.error_last = self.loss.log[-1, -1] + self.optimizer.schedule() def test(self): torch.set_grad_enabled(False) - epoch = self.optimizer.get_last_epoch() + 1 + epoch = self.optimizer.get_last_epoch() self.ckp.write_log('\nEvaluation:') self.ckp.add_log( torch.zeros(1, len(self.loader_test), len(self.scale))