diff --git a/src/model/__init__.py b/src/model/__init__.py
index ee35f5c932f28ff3bd1a4822ffc3054d5340965d..68e28f76bf50aa090829ee0747b7adb2906b0e85 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -76,12 +76,11 @@ class Model(nn.Module):
         for s in save_dirs: torch.save(target.state_dict(), s)
 
     def load(self, apath, pre_train='', resume=-1, cpu=False):
+        load_from = None
+        kwargs = {}
         if cpu:
             kwargs = {'map_location': lambda storage, loc: storage}
-        else:
-            kwargs = {}
 
-        load_from = None
         if resume == -1:
             load_from = torch.load(
                 os.path.join(apath, 'model_latest.pt'),
@@ -106,61 +105,61 @@ class Model(nn.Module):
                 **kwargs
             )
 
-        if load_from: self.get_model().load_state_dict(load_from, strict=False)
+        if load_from:
+            self.get_model().load_state_dict(load_from, strict=False)
 
     def forward_chop(self, *args, shave=10, min_size=160000):
-        if self.input_large:
-            scale = 1
-        else:
-            scale = self.scale[self.idx_scale]
-
+        scale = 1 if self.input_large else self.scale[self.idx_scale]
         n_GPUs = min(self.n_GPUs, 4)
-        _, _, h, w = args[0].size()
-        h_half, w_half = h // 2, w // 2
-        h_size, w_size = h_half + shave, w_half + shave
-        list_x = [[
-            a[:, :, 0:h_size, 0:w_size],
-            a[:, :, 0:h_size, (w - w_size):w],
-            a[:, :, (h - h_size):h, 0:w_size],
-            a[:, :, (h - h_size):h, (w - w_size):w]
-        ] for a in args]
-
-        list_y = []
-        if w_size * h_size < min_size:
+        # height, width
+        h, w = args[0].size()[-2:]
+
+        top = slice(0, h//2 + shave)
+        bottom = slice(h - h//2 - shave, h)
+        left = slice(0, w//2 + shave)
+        right = slice(w - w//2 - shave, w)
+        x_chops = [torch.cat([
+            a[..., top, left],
+            a[..., top, right],
+            a[..., bottom, left],
+            a[..., bottom, right]
+        ]) for a in args]
+
+        y_chops = []
+        if h * w < 4 * min_size:
             for i in range(0, 4, n_GPUs):
-                x = [torch.cat(_x[i:(i + n_GPUs)], dim=0) for _x in list_x]
+                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
                 y = self.model(*x)
                 if not isinstance(y, list): y = [y]
-                if not list_y:
-                    list_y = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
+                if not y_chops:
+                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
                 else:
-                    for _list_y, _y in zip(list_y, y):
-                        _list_y.extend(_y.chunk(n_GPUs, dim=0))
+                    for y_chop, _y in zip(y_chops, y):
+                        y_chop.extend(_y.chunk(n_GPUs, dim=0))
         else:
-            for p in zip(*list_x):
+            for p in zip(*x_chops):
                 y = self.forward_chop(*p, shave=shave, min_size=min_size)
                 if not isinstance(y, list): y = [y]
-                if not list_y:
-                    list_y = [[_y] for _y in y]
+                if not y_chops:
+                    y_chops = [[_y] for _y in y]
                 else:
-                    for _list_y, _y in zip(list_y, y): _list_y.append(_y)
-
-        h, w = scale * h, scale * w
-        h_half, w_half = scale * h_half, scale * w_half
-        h_size, w_size = scale * h_size, scale * w_size
-        shave *= scale
-
-        b, c, _, _ = list_y[0][0].size()
-        y = [_y[0].new(b, c, h, w) for _y in list_y]
-        for _list_y, _y in zip(list_y, y):
-            _y[:, :, :h_half, :w_half] \
-                = _list_y[0][:, :, :h_half, :w_half]
-            _y[:, :, :h_half, w_half:] \
-                = _list_y[1][:, :, :h_half, (w_size - w + w_half):]
-            _y[:, :, h_half:, :w_half] \
-                = _list_y[2][:, :, (h_size - h + h_half):, :w_half]
-            _y[:, :, h_half:, w_half:] \
-                = _list_y[3][:, :, (h_size - h + h_half):, (w_size - w + w_half):]
+                    for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
+
+        top = slice(0, scale * h//2)
+        bottom = slice(scale * (h - h//2), scale * h)
+        bottom_r = slice(scale* (h//2 - h), None)
+        left = slice(0, scale * w//2)
+        right = slice(scale * (w - w//2), scale * w)
+        right_r = slice(scale * w//2, None)
+
+        # batch size, number of color channels
+        b, c = y_chops[0][0].size()[:-2]
+        y = [y_chop[0].new(b, c, scale * h, scale * w) for y_chop in y_chops]
+        for y_chop, _y in zip(y_chops, y):
+            _y[..., top, left] = y_chop[0][..., top, left]
+            _y[..., top, right] = y_chop[1][..., top, right_r]
+            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
+            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
 
         if len(y) == 1: y = y[0]
 
@@ -212,4 +211,3 @@ class Model(nn.Module):
         if len(y) == 1: y = y[0]
 
         return y
-