Commit 565d6fea authored by Sanghyun Son

Now support PyTorch 1.1.0 by default

parent 518c0fa3
# [The NTIRE 2019 challenges are open!](http://www.vision.ee.ethz.ch/ntire19/)
**The challenge winners will be awarded at the CVPR 2019 Workshop.**
![](/figs/ntire2019.png)
# EDSR-PyTorch
**About PyTorch 1.0.0**
* We support PyTorch 1.0.0. If you prefer a previous version of PyTorch, use the legacy branches.
* ``--ext bin`` is no longer supported. Please erase your existing bin files by running once with ``--ext sep-reset``. Once the bin files have been rebuilt successfully, you can remove ``-reset`` from the argument (that is, use ``--ext sep``).
**About PyTorch 1.1.0**
* The 1.1.0 update introduced some minor changes. We now support PyTorch 1.1.0 by default; please use the legacy branches if you prefer an older version.
![](/figs/main.png)
dataloader.py

 import sys
 import threading
-import queue
 import random
-import collections
 import torch
 import torch.multiprocessing as multiprocessing
-from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
-    _remove_worker_pids, _error_if_any_worker_fails
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
+from torch.utils.data import SequentialSampler
+from torch.utils.data import RandomSampler
+from torch.utils.data import BatchSampler
+from torch.utils.data import _utils
 from torch.utils.data.dataloader import _DataLoaderIter
-from torch.utils.data.dataloader import ManagerWatchdog
-from torch.utils.data.dataloader import _pin_memory_loop
-from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL
-from torch.utils.data.dataloader import ExceptionWrapper
-from torch.utils.data.dataloader import _use_shared_memory
-from torch.utils.data.dataloader import numpy_type_map
-from torch.utils.data.dataloader import default_collate
-from torch.utils.data.dataloader import pin_memory_batch
-from torch.utils.data.dataloader import _SIGCHLD_handler_set
-from torch.utils.data.dataloader import _set_SIGCHLD_handler
-if sys.version_info[0] == 2:
-    import Queue as queue
-else:
-    import queue
+from torch.utils.data._utils import collate
+from torch.utils.data._utils import signal_handling
+from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
+from torch.utils.data._utils import ExceptionWrapper
+from torch.utils.data._utils import IS_WINDOWS
+from torch.utils.data._utils.worker import ManagerWatchdog
+from torch._six import queue

 def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
     try:
-        global _use_shared_memory
-        _use_shared_memory = True
-        _set_worker_signal_handlers()
+        collate._use_shared_memory = True
+        signal_handling._set_worker_signal_handlers()
         torch.set_num_threads(1)
         random.seed(seed)
         torch.manual_seed(seed)
         data_queue.cancel_join_thread()
         if init_fn is not None:
@@ -55,6 +46,7 @@ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, se
                 return
             elif done_event.is_set():
                 continue
             idx, batch_indices = r
             try:
                 idx_scale = 0
@@ -68,10 +60,13 @@ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, se
                 data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
             else:
                 data_queue.put((idx, samples))
+                del samples
     except KeyboardInterrupt:
         pass

 class _MSDataLoaderIter(_DataLoaderIter):
     def __init__(self, loader):
         self.dataset = loader.dataset
         self.scale = loader.scale
@@ -118,6 +113,7 @@ class _MSDataLoaderIter(_DataLoaderIter):
                         i
                     )
                 )
+                w.daemon = True
                 w.start()
                 self.index_queues.append(index_queue)
                 self.workers.append(w)
@@ -125,7 +121,7 @@ class _MSDataLoaderIter(_DataLoaderIter):
             if self.pin_memory:
                 self.data_queue = queue.Queue()
                 pin_memory_thread = threading.Thread(
-                    target=_pin_memory_loop,
+                    target=_utils.pin_memory._pin_memory_loop,
                     args=(
                         self.worker_result_queue,
                         self.data_queue,
@@ -139,35 +135,24 @@ class _MSDataLoaderIter(_DataLoaderIter):
             else:
                 self.data_queue = self.worker_result_queue

-            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
-            _set_SIGCHLD_handler()
+            _utils.signal_handling._set_worker_pids(
+                id(self), tuple(w.pid for w in self.workers)
+            )
+            _utils.signal_handling._set_SIGCHLD_handler()
             self.worker_pids_set = True

             for _ in range(2 * self.num_workers):
                 self._put_indices()

 class MSDataLoader(DataLoader):
-    def __init__(
-        self, args, dataset, batch_size=1, shuffle=False,
-        sampler=None, batch_sampler=None,
-        collate_fn=default_collate, pin_memory=False, drop_last=False,
-        timeout=0, worker_init_fn=None):
+    def __init__(self, cfg, *args, **kwargs):
         super(MSDataLoader, self).__init__(
-            dataset,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            sampler=sampler,
-            batch_sampler=batch_sampler,
-            num_workers=args.n_threads,
-            collate_fn=collate_fn,
-            pin_memory=pin_memory,
-            drop_last=drop_last,
-            timeout=timeout,
-            worker_init_fn=worker_init_fn
+            *args, **kwargs, num_workers=cfg.n_threads
         )
-        self.scale = args.scale
+        self.scale = cfg.scale

     def __iter__(self):
         return _MSDataLoaderIter(self)
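After this change, `MSDataLoader` takes the parsed options object first and forwards every remaining argument straight to the standard `DataLoader` constructor, reading only `n_threads` and `scale` from the options. Below is a minimal construction sketch, assuming `MSDataLoader` from the file above is importable; the `cfg` namespace, its values, and the `DummySet` placeholder are hypothetical, and a real dataset from this repository (e.g. an `SRData`-style dataset) would additionally provide the multi-scale hooks that `_ms_loop` uses during iteration.

```python
from types import SimpleNamespace
from torch.utils.data import Dataset

# Hypothetical stand-in for the repository's parsed options; only the two
# attributes MSDataLoader actually reads are set here.
cfg = SimpleNamespace(n_threads=4, scale=[2])

class DummySet(Dataset):
    # Placeholder dataset used only to demonstrate the constructor signature.
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return idx

# Everything after the options object is passed through to DataLoader;
# num_workers is filled in from cfg.n_threads inside MSDataLoader.__init__.
loader = MSDataLoader(cfg, DummySet(), batch_size=4, shuffle=True, pin_memory=True)
print(loader.num_workers, loader.scale)  # -> 4 [2]
```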
trainer.py

@@ -26,7 +26,6 @@ class Trainer():
         self.error_last = 1e8

     def train(self):
-        self.optimizer.schedule()
         self.loss.step()
         epoch = self.optimizer.get_last_epoch() + 1
         lr = self.optimizer.get_lr()
@@ -68,11 +67,12 @@ class Trainer():
         self.loss.end_log(len(self.loader_train))
         self.error_last = self.loss.log[-1, -1]
+        self.optimizer.schedule()

     def test(self):
         torch.set_grad_enabled(False)
-        epoch = self.optimizer.get_last_epoch() + 1
+        epoch = self.optimizer.get_last_epoch()
         self.ckp.write_log('\nEvaluation:')
         self.ckp.add_log(
             torch.zeros(1, len(self.loader_test), len(self.scale))
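The trainer changes are consistent with the PyTorch 1.1.0 convention of stepping the learning-rate scheduler after the optimizer, i.e. at the end of an epoch rather than at its beginning. The following is a minimal sketch of the resulting epoch bookkeeping, using a plain `StepLR` as a rough stand-in for the repository's optimizer wrapper (its `schedule()` / `get_last_epoch()` are assumed to wrap a scheduler like this one); the step size and learning rate are arbitrary.

```python
import torch

# Stand-ins for the repository's model / wrapped optimizer / scheduler.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200)

def train_one_epoch():
    # last_epoch still counts *completed* epochs here, hence the "+ 1" kept in train().
    epoch = scheduler.last_epoch + 1
    # ... forward / backward / optimizer.step() for every batch ...
    scheduler.step()  # now stepped at the end of the epoch, after optimizer.step()
    return epoch

def test():
    # train() has already stepped the scheduler, so the "+ 1" is no longer needed.
    return scheduler.last_epoch
```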