Commit 565d6fea authored by Sanghyun Son

Now support PyTorch 1.1.0 by default

parent 518c0fa3
1 merge request: !1 "Jan 09, 2018 updates"
# [Challenges for NTIRE 2019 are open!](http://www.vision.ee.ethz.ch/ntire19/)
**The challenge winners will be awarded at the CVPR 2019 Workshop.**
![](/figs/ntire2019.png)
# EDSR-PyTorch

**About PyTorch 1.1.0**
* There have been minor changes with the 1.1.0 update. We now support PyTorch 1.1.0 by default; please use the legacy branches if you prefer an older version.
* ``--ext bin`` is not supported. Please erase your existing bin files with ``--ext sep-reset``; once the bin files have been rebuilt successfully, you can remove ``-reset`` from the argument (see the example below).
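For example, assuming the repository's usual ``main.py`` entry point together with whatever training options you normally pass, a one-time run such as ``python main.py --ext sep-reset ...`` erases and rebuilds the binary files; once it finishes successfully, later runs can switch back to ``--ext sep``.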
![](/figs/main.png)
dataloader.py

 import sys
 import threading
-import queue
 import random
-import collections
 
 import torch
 import torch.multiprocessing as multiprocessing
 
-from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
-    _remove_worker_pids, _error_if_any_worker_fails
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
+from torch.utils.data import SequentialSampler
+from torch.utils.data import RandomSampler
+from torch.utils.data import BatchSampler
+from torch.utils.data import _utils
 from torch.utils.data.dataloader import _DataLoaderIter
-from torch.utils.data.dataloader import ManagerWatchdog
-from torch.utils.data.dataloader import _pin_memory_loop
-from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL
-from torch.utils.data.dataloader import ExceptionWrapper
-from torch.utils.data.dataloader import _use_shared_memory
-from torch.utils.data.dataloader import numpy_type_map
-from torch.utils.data.dataloader import default_collate
-from torch.utils.data.dataloader import pin_memory_batch
-from torch.utils.data.dataloader import _SIGCHLD_handler_set
-from torch.utils.data.dataloader import _set_SIGCHLD_handler
-
-if sys.version_info[0] == 2:
-    import Queue as queue
-else:
-    import queue
+
+from torch.utils.data._utils import collate
+from torch.utils.data._utils import signal_handling
+from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
+from torch.utils.data._utils import ExceptionWrapper
+from torch.utils.data._utils import IS_WINDOWS
+from torch.utils.data._utils.worker import ManagerWatchdog
+
+from torch._six import queue
 
 def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
     try:
-        global _use_shared_memory
-        _use_shared_memory = True
-        _set_worker_signal_handlers()
+        collate._use_shared_memory = True
+        signal_handling._set_worker_signal_handlers()
 
         torch.set_num_threads(1)
         random.seed(seed)
         torch.manual_seed(seed)
         data_queue.cancel_join_thread()
 
         if init_fn is not None:
@@ -55,6 +46,7 @@ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
                 return
             elif done_event.is_set():
                 continue
+
             idx, batch_indices = r
             try:
                 idx_scale = 0
@@ -68,10 +60,13 @@ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
                 data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
             else:
                 data_queue.put((idx, samples))
+                del samples
 
     except KeyboardInterrupt:
         pass
 class _MSDataLoaderIter(_DataLoaderIter):
 
     def __init__(self, loader):
         self.dataset = loader.dataset
         self.scale = loader.scale
@@ -118,6 +113,7 @@ class _MSDataLoaderIter(_DataLoaderIter):
                         i
                     )
                 )
+                w.daemon = True
                 w.start()
                 self.index_queues.append(index_queue)
                 self.workers.append(w)
@@ -125,7 +121,7 @@ class _MSDataLoaderIter(_DataLoaderIter):
             if self.pin_memory:
                 self.data_queue = queue.Queue()
                 pin_memory_thread = threading.Thread(
-                    target=_pin_memory_loop,
+                    target=_utils.pin_memory._pin_memory_loop,
                     args=(
                         self.worker_result_queue,
                         self.data_queue,
@@ -139,35 +135,24 @@ class _MSDataLoaderIter(_DataLoaderIter):
             else:
                 self.data_queue = self.worker_result_queue
 
-            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
-            _set_SIGCHLD_handler()
+            _utils.signal_handling._set_worker_pids(
+                id(self), tuple(w.pid for w in self.workers)
+            )
+            _utils.signal_handling._set_SIGCHLD_handler()
             self.worker_pids_set = True
 
             for _ in range(2 * self.num_workers):
                 self._put_indices()
 
 class MSDataLoader(DataLoader):
-    def __init__(
-        self, args, dataset, batch_size=1, shuffle=False,
-        sampler=None, batch_sampler=None,
-        collate_fn=default_collate, pin_memory=False, drop_last=False,
-        timeout=0, worker_init_fn=None):
-
+    def __init__(self, cfg, *args, **kwargs):
         super(MSDataLoader, self).__init__(
-            dataset,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            sampler=sampler,
-            batch_sampler=batch_sampler,
-            num_workers=args.n_threads,
-            collate_fn=collate_fn,
-            pin_memory=pin_memory,
-            drop_last=drop_last,
-            timeout=timeout,
-            worker_init_fn=worker_init_fn
+            *args, **kwargs, num_workers=cfg.n_threads
         )
-        self.scale = args.scale
+        self.scale = cfg.scale
 
     def __iter__(self):
         return _MSDataLoaderIter(self)
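To make the new constructor signature above concrete, here is a minimal, hypothetical usage sketch. ``ToySet``, the ``cfg`` namespace, and all argument values are illustrative placeholders rather than code from this commit; ``cfg`` only needs the ``n_threads`` and ``scale`` attributes that ``MSDataLoader`` actually reads.

```python
# Hypothetical usage sketch for the updated MSDataLoader (not part of this commit).
# Assumes the module shown above is importable as `dataloader`.
from types import SimpleNamespace

import torch
from torch.utils.data import Dataset

from dataloader import MSDataLoader


class ToySet(Dataset):
    """Stand-in dataset exposing the hooks _ms_loop expects (train flag, set_scale)."""
    train = True

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Dummy (LR, HR) pair; a real dataset returns actual image tensors.
        return torch.zeros(3, 48, 48), torch.zeros(3, 96, 96)

    def set_scale(self, idx_scale):
        pass  # a real multi-scale dataset switches its target scale here


cfg = SimpleNamespace(n_threads=2, scale=[2])

# Everything after `cfg` is forwarded to torch's DataLoader;
# num_workers is filled in from cfg.n_threads by MSDataLoader itself.
loader = MSDataLoader(cfg, ToySet(), batch_size=4, shuffle=True)

for batch in loader:
    pass  # iterate exactly like a regular DataLoader
```

The design choice visible in the diff is simply to stop re-listing every ``DataLoader`` keyword argument and forward them with ``*args, **kwargs`` instead, so the wrapper no longer has to track upstream signature changes.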
trainer.py

@@ -26,7 +26,6 @@ class Trainer():
         self.error_last = 1e8
 
     def train(self):
-        self.optimizer.schedule()
         self.loss.step()
         epoch = self.optimizer.get_last_epoch() + 1
         lr = self.optimizer.get_lr()
@@ -68,11 +67,12 @@ class Trainer():
         self.loss.end_log(len(self.loader_train))
         self.error_last = self.loss.log[-1, -1]
+        self.optimizer.schedule()
 
     def test(self):
         torch.set_grad_enabled(False)
 
-        epoch = self.optimizer.get_last_epoch() + 1
+        epoch = self.optimizer.get_last_epoch()
         self.ckp.write_log('\nEvaluation:')
         self.ckp.add_log(
             torch.zeros(1, len(self.loader_test), len(self.scale))