diff --git a/README.md b/README.md index 48dae5cfcb8978444b6dd5ca5793e2edfb7e6ddd..c93e0ee4e82a9a7dda03b67090756c7bde14e8ce 100755 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +**About PyTorch 1.0.0** + * We support PyTorch 1.0.0. If you prefer the previous versions of PyTorch, use legacy branches. + * ``--ext bin`` is not supported. Also, please erase your bin files with ``--ext sep-reset``. Once you successfully build those bin files, you can remove ``-reset`` from the argument. + # EDSR-PyTorch  @@ -27,19 +31,13 @@ We provide scripts for reproducing all the results from our paper. You can train ## Dependencies * Python 3.6 -* PyTorch >= 0.4.0 +* PyTorch >= 1.0.0 * numpy * skimage * **imageio** * matplotlib * tqdm -* cv2 >= 3.xx (Only if you use video input/output) - -**Recent updates** - -* Oct 18, 2018 - * with ``--pre_train download``, pretrained models will be automatically downloaded from server. - * Supports video input/output (inference only). Try with ``--data_test video --dir_demo [video file directory]``. +* cv2 >= 3.xx (Only if you want to use video input/output) ## Code Clone this repository into any place you want. @@ -175,3 +173,7 @@ sh demo.sh * Thanks for recent commits that contains RDN and RCAN. Please see ``code/demo.sh`` to train/test those models. * Now the dataloader is much stable than the previous version. Please erase ``DIV2K/bin`` folder that is created before this commit. Also, please avoid to use ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough spaces(~10GB) in your disk, we recommend ``--ext img``(But SLOW!). +* Oct 18, 2018 + * with ``--pre_train download``, pretrained models will be automatically downloaded from server. + * Supports video input/output (inference only). Try with ``--data_test video --dir_demo [video file directory]``. + diff --git a/src/data/srdata.py b/src/data/srdata.py index 5dcf99d2932d2d987e54663bba9479645ee2f924..b9109aad6d357da941baa8fe2529c43d1d769b52 100644 --- a/src/data/srdata.py +++ b/src/data/srdata.py @@ -28,52 +28,34 @@ class SRData(data.Dataset): os.makedirs(path_bin, exist_ok=True) list_hr, list_lr = self._scan() - if args.ext.find('bin') >= 0: - # Binary files are stored in 'bin' folder - # If the binary file exists, load it. If not, make it. - list_hr, list_lr = self._scan() - self.images_hr = self._check_and_load( - args.ext, list_hr, self._name_hrbin() + if args.ext.find('img') >= 0 or benchmark: + self.images_hr, self.images_lr = list_hr, list_lr + elif args.ext.find('sep') >= 0: + os.makedirs( + self.dir_hr.replace(self.apath, path_bin), + exist_ok=True ) - self.images_lr = [ - self._check_and_load(args.ext, l, self._name_lrbin(s)) \ - for s, l in zip(self.scale, list_lr) - ] - else: - if args.ext.find('img') >= 0 or benchmark: - self.images_hr, self.images_lr = list_hr, list_lr - elif args.ext.find('sep') >= 0: + for s in self.scale: os.makedirs( - self.dir_hr.replace(self.apath, path_bin), + os.path.join( + self.dir_lr.replace(self.apath, path_bin), + 'X{}'.format(s) + ), exist_ok=True ) - for s in self.scale: - os.makedirs( - os.path.join( - self.dir_lr.replace(self.apath, path_bin), - 'X{}'.format(s) - ), - exist_ok=True - ) - - self.images_hr, self.images_lr = [], [[] for _ in self.scale] - for h in list_hr: - b = h.replace(self.apath, path_bin) - b = b.replace(self.ext[0], '.pt') - self.images_hr.append(b) - self._check_and_load( - args.ext, [h], b, verbose=True, load=False - ) - - for i, ll in enumerate(list_lr): - for l in ll: - b = l.replace(self.apath, path_bin) - b = b.replace(self.ext[1], '.pt') - self.images_lr[i].append(b) - self._check_and_load( - args.ext, [l], b, verbose=True, load=False - ) - + + self.images_hr, self.images_lr = [], [[] for _ in self.scale] + for h in list_hr: + b = h.replace(self.apath, path_bin) + b = b.replace(self.ext[0], '.pt') + self.images_hr.append(b) + self._check_and_load(args.ext, h, b, verbose=True) + for i, ll in enumerate(list_lr): + for l in ll: + b = l.replace(self.apath, path_bin) + b = b.replace(self.ext[1], '.pt') + self.images_lr[i].append(b) + self._check_and_load(args.ext, l, b, verbose=True) if train: n_patches = args.batch_size * args.test_every n_images = len(args.data_train) * len(self.images_hr) @@ -106,41 +88,12 @@ class SRData(data.Dataset): if self.input_large: self.dir_lr += 'L' self.ext = ('.png', '.png') - def _name_hrbin(self): - return os.path.join( - self.apath, - 'bin', - '{}_bin_HR.pt'.format(self.split) - ) - - def _name_lrbin(self, scale): - return os.path.join( - self.apath, - 'bin', - '{}_bin_LR_X{}.pt'.format(self.split, scale) - ) - - def _check_and_load(self, ext, l, f, verbose=True, load=True): - if os.path.isfile(f) and ext.find('reset') < 0: - if load: - if verbose: print('Loading {}...'.format(f)) - with open(f, 'rb') as _f: ret = pickle.load(_f) - return ret - else: - return None - else: + def _check_and_load(self, ext, img, f, verbose=True): + if not os.path.isfile(f) or ext.find('reset') >= 0: if verbose: - if ext.find('reset') >= 0: - print('Making a new binary: {}'.format(f)) - else: - print('{} does not exist. Now making binary...'.format(f)) - b = [{ - 'name': os.path.splitext(os.path.basename(_l))[0], - 'image': imageio.imread(_l) - } for _l in l] - with open(f, 'wb') as _f: pickle.dump(b, _f) - - return b + print('Making a binary: {}'.format(f)) + with open(f, 'wb') as _f: + pickle.dump(imageio.imread(img), _f) def __getitem__(self, idx): lr, hr, filename = self._load_file(idx) @@ -167,18 +120,15 @@ class SRData(data.Dataset): f_hr = self.images_hr[idx] f_lr = self.images_lr[self.idx_scale][idx] - if self.args.ext.find('bin') >= 0: - filename = f_hr['name'] - hr = f_hr['image'] - lr = f_lr['image'] - else: - filename, _ = os.path.splitext(os.path.basename(f_hr)) - if self.args.ext == 'img' or self.benchmark: - hr = imageio.imread(f_hr) - lr = imageio.imread(f_lr) - elif self.args.ext.find('sep') >= 0: - with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image'] - with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image'] + filename, _ = os.path.splitext(os.path.basename(f_hr)) + if self.args.ext == 'img' or self.benchmark: + hr = imageio.imread(f_hr) + lr = imageio.imread(f_lr) + elif self.args.ext.find('sep') >= 0: + with open(f_hr, 'rb') as _f: + hr = pickle.load(_f) + with open(f_lr, 'rb') as _f: + lr = pickle.load(_f) return lr, hr, filename diff --git a/src/dataloader.py b/src/dataloader.py index b5405d8124e1fc89bb0238e4ba72bcba5894cd59..16f4e2dc8442656c0daf0316872c051665700bdc 100644 --- a/src/dataloader.py +++ b/src/dataloader.py @@ -11,10 +11,12 @@ from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ _remove_worker_pids, _error_if_any_worker_fails from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import _DataLoaderIter +from torch.utils.data.dataloader import ManagerWatchdog +from torch.utils.data.dataloader import _pin_memory_loop +from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataloader import ExceptionWrapper from torch.utils.data.dataloader import _use_shared_memory -from torch.utils.data.dataloader import _worker_manager_loop from torch.utils.data.dataloader import numpy_type_map from torch.utils.data.dataloader import default_collate from torch.utils.data.dataloader import pin_memory_batch @@ -26,31 +28,48 @@ if sys.version_info[0] == 2: else: import queue -def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id): - global _use_shared_memory - _use_shared_memory = True - _set_worker_signal_handlers() - - torch.set_num_threads(1) - torch.manual_seed(seed) - while True: - r = index_queue.get() - if r is None: - break - idx, batch_indices = r - try: - idx_scale = 0 - if len(scale) > 1 and dataset.train: - idx_scale = random.randrange(0, len(scale)) - dataset.set_scale(idx_scale) - - samples = collate_fn([dataset[i] for i in batch_indices]) - samples.append(idx_scale) - - except Exception: - data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) - else: - data_queue.put((idx, samples)) +def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): + try: + global _use_shared_memory + _use_shared_memory = True + _set_worker_signal_handlers() + + torch.set_num_threads(1) + random.seed(seed) + torch.manual_seed(seed) + data_queue.cancel_join_thread() + + if init_fn is not None: + init_fn(worker_id) + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + + if r is None: + assert done_event.is_set() + return + elif done_event.is_set(): + continue + idx, batch_indices = r + try: + idx_scale = 0 + if len(scale) > 1 and dataset.train: + idx_scale = random.randrange(0, len(scale)) + dataset.set_scale(idx_scale) + + samples = collate_fn([dataset[i] for i in batch_indices]) + samples.append(idx_scale) + except Exception: + data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + data_queue.put((idx, samples)) + except KeyboardInterrupt: + pass class _MSDataLoaderIter(_DataLoaderIter): def __init__(self, loader): @@ -61,32 +80,37 @@ class _MSDataLoaderIter(_DataLoaderIter): self.num_workers = loader.num_workers self.pin_memory = loader.pin_memory and torch.cuda.is_available() self.timeout = loader.timeout - self.done_event = threading.Event() self.sample_iter = iter(self.batch_sampler) + base_seed = torch.LongTensor(1).random_().item() + if self.num_workers > 0: self.worker_init_fn = loader.worker_init_fn - self.index_queues = [ - multiprocessing.Queue() for _ in range(self.num_workers) - ] self.worker_queue_idx = 0 - self.worker_result_queue = multiprocessing.SimpleQueue() + self.worker_result_queue = multiprocessing.Queue() self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False self.send_idx = 0 self.rcvd_idx = 0 self.reorder_dict = {} + self.done_event = multiprocessing.Event() base_seed = torch.LongTensor(1).random_()[0] - self.workers = [ - multiprocessing.Process( + + self.index_queues = [] + self.workers = [] + for i in range(self.num_workers): + index_queue = multiprocessing.Queue() + index_queue.cancel_join_thread() + w = multiprocessing.Process( target=_ms_loop, args=( self.dataset, - self.index_queues[i], + index_queue, self.worker_result_queue, + self.done_event, self.collate_fn, self.scale, base_seed + i, @@ -94,33 +118,31 @@ class _MSDataLoaderIter(_DataLoaderIter): i ) ) - for i in range(self.num_workers)] + w.start() + self.index_queues.append(index_queue) + self.workers.append(w) - if self.pin_memory or self.timeout > 0: + if self.pin_memory: self.data_queue = queue.Queue() - if self.pin_memory: - maybe_device_id = torch.cuda.current_device() - else: - # do not initialize cuda context if not necessary - maybe_device_id = None - self.worker_manager_thread = threading.Thread( - target=_worker_manager_loop, - args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, - maybe_device_id)) - self.worker_manager_thread.daemon = True - self.worker_manager_thread.start() + pin_memory_thread = threading.Thread( + target=_pin_memory_loop, + args=( + self.worker_result_queue, + self.data_queue, + torch.cuda.current_device(), + self.done_event + ) + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + self.pin_memory_thread = pin_memory_thread else: self.data_queue = self.worker_result_queue - for w in self.workers: - w.daemon = True # ensure that the worker exits on process exit - w.start() - _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True - # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices() @@ -132,11 +154,18 @@ class MSDataLoader(DataLoader): timeout=0, worker_init_fn=None): super(MSDataLoader, self).__init__( - dataset, batch_size=batch_size, shuffle=shuffle, - sampler=sampler, batch_sampler=batch_sampler, - num_workers=args.n_threads, collate_fn=collate_fn, - pin_memory=pin_memory, drop_last=drop_last, - timeout=timeout, worker_init_fn=worker_init_fn) + dataset, + batch_size=batch_size, + shuffle=shuffle, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=args.n_threads, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=drop_last, + timeout=timeout, + worker_init_fn=worker_init_fn + ) self.scale = args.scale