import sys
import threading
import random

import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter

from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog
from torch._six import queue


def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn,
             scale, seed, init_fn, worker_id):
    # Worker loop: like the stock _worker_loop, but it picks a random scale
    # per batch and appends the chosen scale index to the collated samples.
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                # The main process asked this worker to exit.
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Shutdown has started; drain any remaining index requests.
                continue

            idx, batch_indices = r
            try:
                # Choose a scale at random for this batch (training only).
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    except KeyboardInterrupt:
        pass


class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator that spawns workers running _ms_loop instead of the default
    worker loop, so every fetched batch carries its scale index."""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                # Pin batches in a background thread before handing them out.
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the pipeline with two outstanding batches per worker.
            for _ in range(2 * self.num_workers):
                self._put_indices()


class MSDataLoader(DataLoader):
    """DataLoader for multi-scale super-resolution training; it reads the
    number of workers and the list of scales from the config object."""

    def __init__(self, cfg, *args, **kwargs):
        super(MSDataLoader, self).__init__(
            *args, **kwargs, num_workers=cfg.n_threads
        )
        self.scale = cfg.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
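

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative). It assumes only what the code above
# requires: a dataset exposing a `train` flag and a `set_scale()` method
# (used by _ms_loop), and a config object providing `n_threads` and `scale`
# (used by MSDataLoader). The toy dataset and SimpleNamespace config are
# placeholders for the project's own dataset and argument parser.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import Dataset

    class _ToyDataset(Dataset):
        train = True

        def __init__(self, scale):
            self.scale = scale
            self.idx_scale = 0

        def set_scale(self, idx_scale):
            # Called from the worker to select the scale for the next batch.
            self.idx_scale = idx_scale

        def __len__(self):
            return 16

        def __getitem__(self, idx):
            s = self.scale[self.idx_scale]
            lr = torch.randn(3, 24, 24)
            hr = torch.randn(3, 24 * s, 24 * s)
            return lr, hr

    cfg = SimpleNamespace(n_threads=2, scale=[2, 3, 4])
    loader = MSDataLoader(cfg, _ToyDataset(cfg.scale), batch_size=4)
    # Each batch is [lr_batch, hr_batch, idx_scale], since _ms_loop appends
    # the scale index to the collated samples.
    for lr, hr, idx_scale in loader:
        print(lr.shape, hr.shape, idx_scale)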