Commit 2904b24b authored by Sanghyun Son

fix dataparallel implementation

Replace the persistent nn.DataParallel wrapper (applied once in __init__) with functional torch.nn.parallel.data_parallel calls at forward time. self.model now always refers to the bare module, so the get_model() helper and its .module indirection in state_dict(), save(), and load() are no longer needed.

parent 32cfb6a4
1 merge request: !1 "Jan 09, 2018 updates"
@@ -3,6 +3,7 @@ from importlib import import_module
 import torch
 import torch.nn as nn
+import torch.nn.parallel as P
 import torch.utils.model_zoo
 
 class Model(nn.Module):
@@ -23,10 +24,8 @@ class Model(nn.Module):
         module = import_module('model.' + args.model.lower())
         self.model = module.make_model(args).to(self.device)
-        if args.precision == 'half': self.model.half()
-
-        if not args.cpu and args.n_GPUs > 1:
-            self.model = nn.DataParallel(self.model, range(args.n_GPUs))
+        if args.precision == 'half':
+            self.model.half()
 
         self.load(
             ckp.get_path('model'),
@@ -38,32 +37,26 @@ class Model(nn.Module):
     def forward(self, x, idx_scale):
         self.idx_scale = idx_scale
-        target = self.get_model()
-        if hasattr(target, 'set_scale'): target.set_scale(idx_scale)
+        if hasattr(self.model, 'set_scale'):
+            self.model.set_scale(idx_scale)
 
-        if self.self_ensemble and not self.training:
+        if self.training:
+            if self.n_GPUs > 1:
+                return P.data_parallel(self.model, x, range(self.n_GPUs))
+            else:
+                return self.model(x)
+        else:
             if self.chop:
                 forward_function = self.forward_chop
             else:
                 forward_function = self.model.forward
 
-            return self.forward_x8(x, forward_function=forward_function)
-        elif self.chop and not self.training:
-            return self.forward_chop(x)
-        else:
-            return self.model(x)
-
-    def get_model(self):
-        if self.n_GPUs == 1:
-            return self.model
-        else:
-            return self.model.module
-
-    def state_dict(self, **kwargs):
-        target = self.get_model()
-        return target.state_dict(**kwargs)
+            if self.self_ensemble:
+                return self.forward_x8(x, forward_function=forward_function)
+            else:
+                return forward_function(x)
 
     def save(self, apath, epoch, is_best=False):
-        target = self.get_model()
         save_dirs = [os.path.join(apath, 'model_latest.pt')]
 
         if is_best:
@@ -73,7 +66,8 @@ class Model(nn.Module):
                 os.path.join(apath, 'model_{}.pt'.format(epoch))
             )
 
-        for s in save_dirs: torch.save(target.state_dict(), s)
+        for s in save_dirs:
+            torch.save(self.model.state_dict(), s)
 
     def load(self, apath, pre_train='', resume=-1, cpu=False):
         load_from = None
@@ -92,7 +86,7 @@ class Model(nn.Module):
                 dir_model = os.path.join('..', 'models')
                 os.makedirs(dir_model, exist_ok=True)
                 load_from = torch.utils.model_zoo.load_url(
-                    self.get_model().url,
+                    self.model.url,
                     model_dir=dir_model,
                     **kwargs
                 )
@@ -106,7 +100,7 @@ class Model(nn.Module):
             )
 
         if load_from:
-            self.get_model().load_state_dict(load_from, strict=False)
+            self.model.load_state_dict(load_from, strict=False)
 
     def forward_chop(self, *args, shave=10, min_size=160000):
         scale = 1 if self.input_large else self.scale[self.idx_scale]
@@ -129,7 +123,7 @@ class Model(nn.Module):
         if h * w < 4 * min_size:
             for i in range(0, 4, n_GPUs):
                 x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
-                y = self.model(*x)
+                y = P.data_parallel(self.model, *x, range(n_GPUs))
                 if not isinstance(y, list): y = [y]
                 if not y_chops:
                     y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
...
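For context on what changed: nn.DataParallel wraps a module once and owns it afterwards, so the real weights and attributes live behind wrapper.module (hence the removed get_model() helper), whereas the functional torch.nn.parallel.data_parallel replicates the bare module on each call and leaves it untouched. Below is a minimal sketch of the contrast, assuming at least one CUDA device; the Conv2d stand-in model and tensor shapes are illustrative, not code from this repository.

import torch
import torch.nn as nn
import torch.nn.parallel as P

n_GPUs = torch.cuda.device_count()
model = nn.Conv2d(3, 64, 3, padding=1).cuda()  # stand-in for the SR model
x = torch.randn(16, 3, 48, 48).cuda()

# Wrapper style (removed by this commit): DataParallel owns the module,
# so the actual weights sit behind .module.
wrapped = nn.DataParallel(model, device_ids=list(range(n_GPUs)))
y = wrapped(x)
state = wrapped.module.state_dict()  # .module indirection required

# Functional style (adopted by this commit): the bare module is
# scattered across devices per call; state_dict() and attributes
# work on it directly.
y = P.data_parallel(model, x, device_ids=list(range(n_GPUs)))
state = model.state_dict()  # no unwrapping needed

The same functional call is reused in forward_chop to spread the chopped sub-images across GPUs at test time. Both styles replicate the module across devices on each forward pass; the practical difference is ownership: the functional form keeps self.model as the plain module, so checkpoint keys and attributes such as self.model.url and set_scale stay directly accessible.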