diff --git a/src/model/__init__.py b/src/model/__init__.py
index 68e28f76bf50aa090829ee0747b7adb2906b0e85..dca13eae6cf659e6d181d2ff6106d7e8e2ad4d0e 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -3,6 +3,7 @@ from importlib import import_module
 
 import torch
 import torch.nn as nn
+import torch.nn.parallel as P
 import torch.utils.model_zoo
 
 class Model(nn.Module):
@@ -23,10 +24,8 @@ class Model(nn.Module):
 
         module = import_module('model.' + args.model.lower())
         self.model = module.make_model(args).to(self.device)
-        if args.precision == 'half': self.model.half()
-
-        if not args.cpu and args.n_GPUs > 1:
-            self.model = nn.DataParallel(self.model, range(args.n_GPUs))
+        if args.precision == 'half':
+            self.model.half()
 
         self.load(
             ckp.get_path('model'),
@@ -38,32 +37,26 @@ class Model(nn.Module):
 
     def forward(self, x, idx_scale):
         self.idx_scale = idx_scale
-        target = self.get_model()
-        if hasattr(target, 'set_scale'): target.set_scale(idx_scale)
-        if self.self_ensemble and not self.training:
+        if hasattr(self.model, 'set_scale'):
+            self.model.set_scale(idx_scale)
+
+        if self.training:
+            if self.n_GPUs > 1:
+                return P.data_parallel(self.model, x, range(self.n_GPUs))
+            else:
+                return self.model(x)
+        else:
             if self.chop:
                 forward_function = self.forward_chop
             else:
                 forward_function = self.model.forward
 
-            return self.forward_x8(x, forward_function=forward_function)
-        elif self.chop and not self.training:
-            return self.forward_chop(x)
-        else:
-            return self.model(x)
-
-    def get_model(self):
-        if self.n_GPUs == 1:
-            return self.model
-        else:
-            return self.model.module
-
-    def state_dict(self, **kwargs):
-        target = self.get_model()
-        return target.state_dict(**kwargs)
+            if self.self_ensemble:
+                return self.forward_x8(x, forward_function=forward_function)
+            else:
+                return forward_function(x)
 
     def save(self, apath, epoch, is_best=False):
-        target = self.get_model()
         save_dirs = [os.path.join(apath, 'model_latest.pt')]
 
         if is_best:
@@ -73,7 +66,8 @@ class Model(nn.Module):
                 os.path.join(apath, 'model_{}.pt'.format(epoch))
             )
 
-        for s in save_dirs: torch.save(target.state_dict(), s)
+        for s in save_dirs:
+            torch.save(self.model.state_dict(), s)
 
     def load(self, apath, pre_train='', resume=-1, cpu=False):
         load_from = None
@@ -92,7 +86,7 @@ class Model(nn.Module):
                 dir_model = os.path.join('..', 'models')
                 os.makedirs(dir_model, exist_ok=True)
                 load_from = torch.utils.model_zoo.load_url(
-                    self.get_model().url,
+                    self.model.url,
                     model_dir=dir_model,
                     **kwargs
                 )
@@ -106,7 +100,7 @@ class Model(nn.Module):
             )
 
         if load_from:
-            self.get_model().load_state_dict(load_from, strict=False)
+            self.model.load_state_dict(load_from, strict=False)
 
     def forward_chop(self, *args, shave=10, min_size=160000):
         scale = 1 if self.input_large else self.scale[self.idx_scale]
@@ -129,7 +123,7 @@ class Model(nn.Module):
         if h * w < 4 * min_size:
             for i in range(0, 4, n_GPUs):
                 x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
-                y = self.model(*x)
+                y = P.data_parallel(self.model, *x, range(n_GPUs))
                 if not isinstance(y, list): y = [y]
                 if not y_chops:
                     y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]