Skip to content
Snippets Groups Projects
Commit 9a9d7d74 authored by Sanghyun Son's avatar Sanghyun Son
Browse files

support 1.0.0

parent a90b54d1
No related branches found
No related tags found
1 merge request!1Jan 09, 2018 updates
**About PyTorch 1.0.0**
* We support PyTorch 1.0.0. If you prefer the previous versions of PyTorch, use legacy branches.
* ``--ext bin`` is not supported. Also, please erase your bin files with ``--ext sep-reset``. Once you successfully build those bin files, you can remove ``-reset`` from the argument.
# EDSR-PyTorch
![](/figs/main.png)
......@@ -27,19 +31,13 @@ We provide scripts for reproducing all the results from our paper. You can train
## Dependencies
* Python 3.6
* PyTorch >= 0.4.0
* PyTorch >= 1.0.0
* numpy
* skimage
* **imageio**
* matplotlib
* tqdm
* cv2 >= 3.xx (Only if you use video input/output)
**Recent updates**
* Oct 18, 2018
* with ``--pre_train download``, pretrained models will be automatically downloaded from the server.
* Supports video input/output (inference only). Try with ``--data_test video --dir_demo [video file directory]``.
* cv2 >= 3.xx (Only if you want to use video input/output)
## Code
Clone this repository into any place you want.
......@@ -175,3 +173,7 @@ sh demo.sh
* Thanks for recent commits that contains RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
* The dataloader is now much more stable than the previous version. Please erase the ``DIV2K/bin`` folder that was created before this commit. Also, please avoid using the ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but SLOW!).
* Oct 18, 2018
* with ``--pre_train download``, pretrained models will be automatically downloaded from the server.
* Supports video input/output (inference only). Try with ``--data_test video --dir_demo [video file directory]``.
......@@ -28,18 +28,6 @@ class SRData(data.Dataset):
os.makedirs(path_bin, exist_ok=True)
list_hr, list_lr = self._scan()
if args.ext.find('bin') >= 0:
# Binary files are stored in 'bin' folder
# If the binary file exists, load it. If not, make it.
list_hr, list_lr = self._scan()
self.images_hr = self._check_and_load(
args.ext, list_hr, self._name_hrbin()
)
self.images_lr = [
self._check_and_load(args.ext, l, self._name_lrbin(s)) \
for s, l in zip(self.scale, list_lr)
]
else:
if args.ext.find('img') >= 0 or benchmark:
self.images_hr, self.images_lr = list_hr, list_lr
elif args.ext.find('sep') >= 0:
......@@ -61,19 +49,13 @@ class SRData(data.Dataset):
b = h.replace(self.apath, path_bin)
b = b.replace(self.ext[0], '.pt')
self.images_hr.append(b)
self._check_and_load(
args.ext, [h], b, verbose=True, load=False
)
self._check_and_load(args.ext, h, b, verbose=True)
for i, ll in enumerate(list_lr):
for l in ll:
b = l.replace(self.apath, path_bin)
b = b.replace(self.ext[1], '.pt')
self.images_lr[i].append(b)
self._check_and_load(
args.ext, [l], b, verbose=True, load=False
)
self._check_and_load(args.ext, l, b, verbose=True)
if train:
n_patches = args.batch_size * args.test_every
n_images = len(args.data_train) * len(self.images_hr)
......@@ -106,41 +88,12 @@ class SRData(data.Dataset):
if self.input_large: self.dir_lr += 'L'
self.ext = ('.png', '.png')
def _name_hrbin(self):
    """Path of the packed high-resolution binary for this dataset split.

    Joins ``<apath>/bin`` with a split-specific file name, e.g.
    ``train_bin_HR.pt``.
    """
    filename = '{}_bin_HR.pt'.format(self.split)
    return os.path.join(self.apath, 'bin', filename)
def _name_lrbin(self, scale):
    """Path of the packed low-resolution binary for this split and scale.

    Joins ``<apath>/bin`` with a split/scale-specific file name, e.g.
    ``train_bin_LR_X2.pt``.
    """
    filename = '{}_bin_LR_X{}.pt'.format(self.split, scale)
    return os.path.join(self.apath, 'bin', filename)
def _check_and_load(self, ext, l, f, verbose=True, load=True):
if os.path.isfile(f) and ext.find('reset') < 0:
if load:
if verbose: print('Loading {}...'.format(f))
with open(f, 'rb') as _f: ret = pickle.load(_f)
return ret
else:
return None
else:
def _check_and_load(self, ext, img, f, verbose=True):
if not os.path.isfile(f) or ext.find('reset') >= 0:
if verbose:
if ext.find('reset') >= 0:
print('Making a new binary: {}'.format(f))
else:
print('{} does not exist. Now making binary...'.format(f))
b = [{
'name': os.path.splitext(os.path.basename(_l))[0],
'image': imageio.imread(_l)
} for _l in l]
with open(f, 'wb') as _f: pickle.dump(b, _f)
return b
print('Making a binary: {}'.format(f))
with open(f, 'wb') as _f:
pickle.dump(imageio.imread(img), _f)
def __getitem__(self, idx):
lr, hr, filename = self._load_file(idx)
......@@ -167,18 +120,15 @@ class SRData(data.Dataset):
f_hr = self.images_hr[idx]
f_lr = self.images_lr[self.idx_scale][idx]
if self.args.ext.find('bin') >= 0:
filename = f_hr['name']
hr = f_hr['image']
lr = f_lr['image']
else:
filename, _ = os.path.splitext(os.path.basename(f_hr))
if self.args.ext == 'img' or self.benchmark:
hr = imageio.imread(f_hr)
lr = imageio.imread(f_lr)
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image']
with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image']
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
with open(f_lr, 'rb') as _f:
lr = pickle.load(_f)
return lr, hr, filename
......
......@@ -11,10 +11,12 @@ from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
_remove_worker_pids, _error_if_any_worker_fails
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data.dataloader import ManagerWatchdog
from torch.utils.data.dataloader import _pin_memory_loop
from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import ExceptionWrapper
from torch.utils.data.dataloader import _use_shared_memory
from torch.utils.data.dataloader import _worker_manager_loop
from torch.utils.data.dataloader import numpy_type_map
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import pin_memory_batch
......@@ -26,17 +28,33 @@ if sys.version_info[0] == 2:
else:
import queue
def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
try:
global _use_shared_memory
_use_shared_memory = True
_set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
while True:
r = index_queue.get()
data_queue.cancel_join_thread()
if init_fn is not None:
init_fn(worker_id)
watchdog = ManagerWatchdog()
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if r is None:
break
assert done_event.is_set()
return
elif done_event.is_set():
continue
idx, batch_indices = r
try:
idx_scale = 0
......@@ -46,11 +64,12 @@ def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn,
samples = collate_fn([dataset[i] for i in batch_indices])
samples.append(idx_scale)
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
except KeyboardInterrupt:
pass
class _MSDataLoaderIter(_DataLoaderIter):
def __init__(self, loader):
......@@ -61,32 +80,37 @@ class _MSDataLoaderIter(_DataLoaderIter):
self.num_workers = loader.num_workers
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
self.timeout = loader.timeout
self.done_event = threading.Event()
self.sample_iter = iter(self.batch_sampler)
base_seed = torch.LongTensor(1).random_().item()
if self.num_workers > 0:
self.worker_init_fn = loader.worker_init_fn
self.index_queues = [
multiprocessing.Queue() for _ in range(self.num_workers)
]
self.worker_queue_idx = 0
self.worker_result_queue = multiprocessing.SimpleQueue()
self.worker_result_queue = multiprocessing.Queue()
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
self.send_idx = 0
self.rcvd_idx = 0
self.reorder_dict = {}
self.done_event = multiprocessing.Event()
base_seed = torch.LongTensor(1).random_()[0]
self.workers = [
multiprocessing.Process(
self.index_queues = []
self.workers = []
for i in range(self.num_workers):
index_queue = multiprocessing.Queue()
index_queue.cancel_join_thread()
w = multiprocessing.Process(
target=_ms_loop,
args=(
self.dataset,
self.index_queues[i],
index_queue,
self.worker_result_queue,
self.done_event,
self.collate_fn,
self.scale,
base_seed + i,
......@@ -94,33 +118,31 @@ class _MSDataLoaderIter(_DataLoaderIter):
i
)
)
for i in range(self.num_workers)]
w.start()
self.index_queues.append(index_queue)
self.workers.append(w)
if self.pin_memory or self.timeout > 0:
self.data_queue = queue.Queue()
if self.pin_memory:
maybe_device_id = torch.cuda.current_device()
else:
# do not initialize cuda context if not necessary
maybe_device_id = None
self.worker_manager_thread = threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
maybe_device_id))
self.worker_manager_thread.daemon = True
self.worker_manager_thread.start()
self.data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
target=_pin_memory_loop,
args=(
self.worker_result_queue,
self.data_queue,
torch.cuda.current_device(),
self.done_event
)
)
pin_memory_thread.daemon = True
pin_memory_thread.start()
self.pin_memory_thread = pin_memory_thread
else:
self.data_queue = self.worker_result_queue
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True
# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()
......@@ -132,11 +154,18 @@ class MSDataLoader(DataLoader):
timeout=0, worker_init_fn=None):
super(MSDataLoader, self).__init__(
dataset, batch_size=batch_size, shuffle=shuffle,
sampler=sampler, batch_sampler=batch_sampler,
num_workers=args.n_threads, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=args.n_threads,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
worker_init_fn=worker_init_fn
)
self.scale = args.scale
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment