Commit 2904b24b authored by Sanghyun Son

fix dataparallel implementation

parent 32cfb6a4
1 merge request: !1 "Jan 09, 2018 updates"
@@ -3,6 +3,7 @@ from importlib import import_module

 import torch
 import torch.nn as nn
+import torch.nn.parallel as P
 import torch.utils.model_zoo

 class Model(nn.Module):
@@ -23,10 +24,8 @@ class Model(nn.Module):
         module = import_module('model.' + args.model.lower())
         self.model = module.make_model(args).to(self.device)
-        if args.precision == 'half': self.model.half()
-        if not args.cpu and args.n_GPUs > 1:
-            self.model = nn.DataParallel(self.model, range(args.n_GPUs))
+        if args.precision == 'half':
+            self.model.half()

         self.load(
             ckp.get_path('model'),
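The change above removes the persistent nn.DataParallel wrapper that the old constructor applied whenever n_GPUs > 1; replication now happens per call through torch.nn.parallel (imported as P in the first hunk). A minimal sketch of the two styles, assuming a multi-GPU machine; the toy conv layer and tensor shapes are illustrations, not code from this repository:

import torch
import torch.nn as nn
import torch.nn.parallel as P

net = nn.Conv2d(3, 8, 3, padding=1)
x = torch.randn(4, 3, 16, 16)

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    n_GPUs = torch.cuda.device_count()
    net, x = net.cuda(), x.cuda()

    # Old style: wrap once at construction. self.model becomes the wrapper,
    # so the underlying network must be reached via self.model.module.
    wrapped = nn.DataParallel(net, range(n_GPUs))
    y_old = wrapped(x)

    # New style: keep the bare module and replicate only for this call;
    # attributes like net.url or net.set_scale stay directly reachable.
    y_new = P.data_parallel(net, x, range(n_GPUs))

Dropping the wrapper is what lets the rest of this commit delete get_model(): self.model is now always the bare network.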
@@ -38,32 +37,26 @@ class Model(nn.Module):
     def forward(self, x, idx_scale):
         self.idx_scale = idx_scale
-        target = self.get_model()
-        if hasattr(target, 'set_scale'): target.set_scale(idx_scale)
+        if hasattr(self.model, 'set_scale'):
+            self.model.set_scale(idx_scale)

-        if self.self_ensemble and not self.training:
+        if self.training:
+            if self.n_GPUs > 1:
+                return P.data_parallel(self.model, x, range(self.n_GPUs))
+            else:
+                return self.model(x)
+        else:
             if self.chop:
                 forward_function = self.forward_chop
             else:
                 forward_function = self.model.forward

-            return self.forward_x8(x, forward_function=forward_function)
-        elif self.chop and not self.training:
-            return self.forward_chop(x)
-        else:
-            return self.model(x)
-
-    def get_model(self):
-        if self.n_GPUs == 1:
-            return self.model
-        else:
-            return self.model.module
-
-    def state_dict(self, **kwargs):
-        target = self.get_model()
-        return target.state_dict(**kwargs)
+            if self.self_ensemble:
+                return self.forward_x8(x, forward_function=forward_function)
+            else:
+                return forward_function(x)
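The rewritten forward replicates across GPUs only in training; at test time a single replica goes through chopping or self-ensemble. Functionally, P.data_parallel scatters the batch along dim 0, replicates the module, runs each shard on its own device, and gathers the outputs on device 0, so it is equivalent to a plain forward on one device. A self-contained sketch of that equivalence; the Double module is a stand-in for illustration:

import torch
import torch.nn as nn
import torch.nn.parallel as P

class Double(nn.Module):
    def forward(self, x):
        return 2 * x

model = Double()
x = torch.arange(8.0).view(4, 2)

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # Scatter x along dim 0, replicate model, gather the outputs on device 0.
    y = P.data_parallel(model.cuda(), x.cuda(), range(torch.cuda.device_count()))
else:
    # On a single device the functional call reduces to a plain forward.
    y = model(x)

assert torch.equal(y.cpu(), 2 * x)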
     def save(self, apath, epoch, is_best=False):
-        target = self.get_model()
         save_dirs = [os.path.join(apath, 'model_latest.pt')]
         if is_best:
@@ -73,7 +66,8 @@ class Model(nn.Module):
                 os.path.join(apath, 'model_{}.pt'.format(epoch))
             )
-        for s in save_dirs: torch.save(target.state_dict(), s)
+        for s in save_dirs:
+            torch.save(self.model.state_dict(), s)
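Saving self.model.state_dict() directly works here precisely because the wrapper is gone: nn.DataParallel registers the wrapped network under the attribute name module, so its checkpoints carry a 'module.' key prefix that a bare network cannot load without stripping. A quick illustration; the Linear layer is arbitrary:

import torch.nn as nn

net = nn.Linear(4, 4)
print(list(net.state_dict().keys()))                   # ['weight', 'bias']
print(list(nn.DataParallel(net).state_dict().keys()))  # ['module.weight', 'module.bias']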
     def load(self, apath, pre_train='', resume=-1, cpu=False):
         load_from = None
@@ -92,7 +86,7 @@ class Model(nn.Module):
             dir_model = os.path.join('..', 'models')
             os.makedirs(dir_model, exist_ok=True)
             load_from = torch.utils.model_zoo.load_url(
-                self.get_model().url,
+                self.model.url,
                 model_dir=dir_model,
                 **kwargs
             )
@@ -106,7 +100,7 @@ class Model(nn.Module):
             )

         if load_from:
-            self.get_model().load_state_dict(load_from, strict=False)
+            self.model.load_state_dict(load_from, strict=False)
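The strict=False argument matters in the line above: it tolerates missing and unexpected keys, so a checkpoint whose keys do not line up exactly with the model loads whatever does match instead of raising. A small demonstration with toy modules of my own; only the documented PyTorch behavior of strict=False is assumed:

import torch.nn as nn

pretrained = nn.Sequential(nn.Linear(4, 4))               # keys: 0.weight, 0.bias
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))   # also has 1.weight, 1.bias

# strict=True would raise because 1.weight / 1.bias are missing from the
# checkpoint; strict=False loads the overlapping keys and skips the rest.
model.load_state_dict(pretrained.state_dict(), strict=False)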
     def forward_chop(self, *args, shave=10, min_size=160000):
         scale = 1 if self.input_large else self.scale[self.idx_scale]
@@ -129,7 +123,7 @@ class Model(nn.Module):
         if h * w < 4 * min_size:
             for i in range(0, 4, n_GPUs):
                 x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
-                y = self.model(*x)
+                y = P.data_parallel(self.model, *x, range(n_GPUs))
                 if not isinstance(y, list): y = [y]
                 if not y_chops:
                     y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
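In forward_chop, the input is split into four overlapping corner patches stacked along the batch dimension; the loop above feeds them to the network in groups of n_GPUs, so each pass upscales one patch per device. A simplified sketch of just that grouping, with the shaving and output stitching omitted; the conv layer stands in for the real super-resolution model:

import torch
import torch.nn as nn
import torch.nn.parallel as P

model = nn.Conv2d(3, 3, 3, padding=1)
n_GPUs = max(torch.cuda.device_count(), 1)

patches = torch.randn(4, 3, 32, 32)  # four corner patches as one batch
outputs = []
for i in range(0, 4, n_GPUs):
    x = patches[i:(i + n_GPUs)]
    if torch.cuda.is_available() and n_GPUs > 1:
        # One patch per device; outputs are gathered back on device 0.
        y = P.data_parallel(model.cuda(), x.cuda(), range(n_GPUs)).cpu()
    else:
        y = model(x)
    outputs.append(y)

y = torch.cat(outputs, dim=0)  # one processed patch per original corner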