diff --git a/src/model/__init__.py b/src/model/__init__.py
index 68e28f76bf50aa090829ee0747b7adb2906b0e85..dca13eae6cf659e6d181d2ff6106d7e8e2ad4d0e 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -3,6 +3,7 @@ from importlib import import_module
 
 import torch
 import torch.nn as nn
+import torch.nn.parallel as P
 import torch.utils.model_zoo
 
 class Model(nn.Module):
@@ -23,10 +24,8 @@ class Model(nn.Module):
 
         module = import_module('model.' + args.model.lower())
         self.model = module.make_model(args).to(self.device)
-        if args.precision == 'half': self.model.half()
-
-        if not args.cpu and args.n_GPUs > 1:
-            self.model = nn.DataParallel(self.model, range(args.n_GPUs))
+        if args.precision == 'half':
+            self.model.half()
 
         self.load(
             ckp.get_path('model'),
@@ -38,32 +37,26 @@ class Model(nn.Module):
 
     def forward(self, x, idx_scale):
         self.idx_scale = idx_scale
-        target = self.get_model()
-        if hasattr(target, 'set_scale'): target.set_scale(idx_scale)
-        if self.self_ensemble and not self.training:
+        if hasattr(self.model, 'set_scale'):
+            self.model.set_scale(idx_scale)
+
+        if self.training:
+            if self.n_GPUs > 1:
+                return P.data_parallel(self.model, x, range(self.n_GPUs))
+            else:
+                return self.model(x)
+        else:
             if self.chop:
                 forward_function = self.forward_chop
             else:
                 forward_function = self.model.forward
 
-            return self.forward_x8(x, forward_function=forward_function)
-        elif self.chop and not self.training:
-            return self.forward_chop(x)
-        else:
-            return self.model(x)
-
-    def get_model(self):
-        if self.n_GPUs == 1:
-            return self.model
-        else:
-            return self.model.module
-
-    def state_dict(self, **kwargs):
-        target = self.get_model()
-        return target.state_dict(**kwargs)
+            if self.self_ensemble:
+                return self.forward_x8(x, forward_function=forward_function)
+            else:
+                return forward_function(x)
 
     def save(self, apath, epoch, is_best=False):
-        target = self.get_model()
         save_dirs = [os.path.join(apath, 'model_latest.pt')]
 
         if is_best:
@@ -73,7 +66,8 @@ class Model(nn.Module):
                 os.path.join(apath, 'model_{}.pt'.format(epoch))
             )
 
-        for s in save_dirs: torch.save(target.state_dict(), s)
+        for s in save_dirs:
+            torch.save(self.model.state_dict(), s)
 
     def load(self, apath, pre_train='', resume=-1, cpu=False):
         load_from = None
@@ -92,7 +86,7 @@ class Model(nn.Module):
                 dir_model = os.path.join('..', 'models')
                 os.makedirs(dir_model, exist_ok=True)
                 load_from = torch.utils.model_zoo.load_url(
-                    self.get_model().url,
+                    self.model.url,
                     model_dir=dir_model,
                     **kwargs
                 )
@@ -106,7 +100,7 @@ class Model(nn.Module):
             )
 
         if load_from:
-            self.get_model().load_state_dict(load_from, strict=False)
+            self.model.load_state_dict(load_from, strict=False)
 
     def forward_chop(self, *args, shave=10, min_size=160000):
         scale = 1 if self.input_large else self.scale[self.idx_scale]
@@ -129,7 +123,7 @@ class Model(nn.Module):
         if h * w < 4 * min_size:
             for i in range(0, 4, n_GPUs):
                 x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
-                y = self.model(*x)
+                y = P.data_parallel(self.model, *x, range(n_GPUs))
                 if not isinstance(y, list): y = [y]
                 if not y_chops:
                     y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
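
The core of this change is replacing the construction-time `nn.DataParallel` wrapper with the functional `torch.nn.parallel.data_parallel` call inside `forward`, which is why `get_model()` and the custom `state_dict()` override can be deleted: `self.model` is now always the plain module, so saving and loading no longer need to unwrap `.module`. Below is a minimal standalone sketch contrasting the two styles; the toy `nn.Conv2d` model and tensor shapes are illustrative only and not part of the patch.

```python
# Illustrative sketch only -- a toy module stands in for the SR model.
import torch
import torch.nn as nn
import torch.nn.parallel as P

model = nn.Conv2d(3, 3, 3, padding=1)      # stand-in for module.make_model(args)
x = torch.randn(1, 3, 32, 32)
n_GPUs = torch.cuda.device_count()

# Old style: wrap once at init; the real module then hides behind .module,
# so saving/loading had to go through get_model().
if n_GPUs > 1:
    wrapped = nn.DataParallel(model.cuda(), list(range(n_GPUs)))
    y = wrapped(x.cuda())
    state = wrapped.module.state_dict()    # must unwrap before torch.save

# New style: keep the plain module and scatter only per call,
# so state_dict()/load_state_dict() work on it directly.
if n_GPUs > 1:
    y = P.data_parallel(model.cuda(), x.cuda(), list(range(n_GPUs)))
else:
    y = model(x)
state = model.state_dict()                 # no unwrapping needed
```

The patch applies the same per-call scattering in `forward_chop`, so chopped sub-batches are also spread across GPUs during evaluation rather than relying on a permanently wrapped model.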