diff --git a/src/model/__init__.py b/src/model/__init__.py
index ee35f5c932f28ff3bd1a4822ffc3054d5340965d..68e28f76bf50aa090829ee0747b7adb2906b0e85 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -76,12 +76,11 @@ class Model(nn.Module):
         for s in save_dirs: torch.save(target.state_dict(), s)
 
     def load(self, apath, pre_train='', resume=-1, cpu=False):
+        load_from = None
+        kwargs = {}
         if cpu:
             kwargs = {'map_location': lambda storage, loc: storage}
-        else:
-            kwargs = {}
 
-        load_from = None
         if resume == -1:
             load_from = torch.load(
                 os.path.join(apath, 'model_latest.pt'),
@@ -106,61 +105,61 @@ class Model(nn.Module):
                 **kwargs
             )
 
-        if load_from: self.get_model().load_state_dict(load_from, strict=False)
+        if load_from:
+            self.get_model().load_state_dict(load_from, strict=False)
 
     def forward_chop(self, *args, shave=10, min_size=160000):
-        if self.input_large:
-            scale = 1
-        else:
-            scale = self.scale[self.idx_scale]
-
+        scale = 1 if self.input_large else self.scale[self.idx_scale]
         n_GPUs = min(self.n_GPUs, 4)
-        _, _, h, w = args[0].size()
-        h_half, w_half = h // 2, w // 2
-        h_size, w_size = h_half + shave, w_half + shave
-        list_x = [[
-            a[:, :, 0:h_size, 0:w_size],
-            a[:, :, 0:h_size, (w - w_size):w],
-            a[:, :, (h - h_size):h, 0:w_size],
-            a[:, :, (h - h_size):h, (w - w_size):w]
-        ] for a in args]
-
-        list_y = []
-        if w_size * h_size < min_size:
+        # height, width
+        h, w = args[0].size()[-2:]
+
+        top = slice(0, h//2 + shave)
+        bottom = slice(h - h//2 - shave, h)
+        left = slice(0, w//2 + shave)
+        right = slice(w - w//2 - shave, w)
+        x_chops = [torch.cat([
+            a[..., top, left],
+            a[..., top, right],
+            a[..., bottom, left],
+            a[..., bottom, right]
+        ]) for a in args]
+
+        y_chops = []
+        if h * w < 4 * min_size:
             for i in range(0, 4, n_GPUs):
-                x = [torch.cat(_x[i:(i + n_GPUs)], dim=0) for _x in list_x]
+                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
                 y = self.model(*x)
                 if not isinstance(y, list): y = [y]
-                if not list_y:
-                    list_y = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
+                if not y_chops:
+                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
                 else:
-                    for _list_y, _y in zip(list_y, y):
-                        _list_y.extend(_y.chunk(n_GPUs, dim=0))
+                    for y_chop, _y in zip(y_chops, y):
+                        y_chop.extend(_y.chunk(n_GPUs, dim=0))
         else:
-            for p in zip(*list_x):
+            for p in zip(*x_chops):
                 y = self.forward_chop(*p, shave=shave, min_size=min_size)
                 if not isinstance(y, list): y = [y]
-                if not list_y:
-                    list_y = [[_y] for _y in y]
+                if not y_chops:
+                    y_chops = [[_y] for _y in y]
                 else:
-                    for _list_y, _y in zip(list_y, y): _list_y.append(_y)
-
-        h, w = scale * h, scale * w
-        h_half, w_half = scale * h_half, scale * w_half
-        h_size, w_size = scale * h_size, scale * w_size
-        shave *= scale
-
-        b, c, _, _ = list_y[0][0].size()
-        y = [_y[0].new(b, c, h, w) for _y in list_y]
-        for _list_y, _y in zip(list_y, y):
-            _y[:, :, :h_half, :w_half] \
-                = _list_y[0][:, :, :h_half, :w_half]
-            _y[:, :, :h_half, w_half:] \
-                = _list_y[1][:, :, :h_half, (w_size - w + w_half):]
-            _y[:, :, h_half:, :w_half] \
-                = _list_y[2][:, :, (h_size - h + h_half):, :w_half]
-            _y[:, :, h_half:, w_half:] \
-                = _list_y[3][:, :, (h_size - h + h_half):, (w_size - w + w_half):]
+                    for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
+
+        top = slice(0, scale * h//2)
+        bottom = slice(scale * (h - h//2), scale * h)
+        bottom_r = slice(scale * (h//2 - h), None)
+        left = slice(0, scale * w//2)
+        right = slice(scale * (w - w//2), scale * w)
+        right_r = slice(scale * (w//2 - w), None)
+
+        # batch size, number of color channels
+        b, c = y_chops[0][0].size()[:-2]
+        y = [y_chop[0].new(b, c, scale * h, scale * w) for y_chop in y_chops]
+        for y_chop, _y in zip(y_chops, y):
+            _y[..., top, left] = y_chop[0][..., top, left]
+            _y[..., top, right] = y_chop[1][..., top, right_r]
+            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
+            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
 
         if len(y) == 1: y = y[0]
 
@@ -212,4 +211,3 @@ class Model(nn.Module):
         if len(y) == 1: y = y[0]
 
         return y
-
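A quick sanity check for the slice arithmetic above (not part of the patch): the sketch below splits a tensor into four overlapping quadrants using the same input-space slices as the new forward_chop, runs a nearest-neighbour upsample in place of self.model, and reassembles the result with the output-space slices. The helper name chop_and_merge and the test sizes are illustrative only; even spatial sizes are assumed so the halves tile the output exactly, and the multi-GPU batching and recursive fallback are omitted.

import torch
import torch.nn.functional as F

def chop_and_merge(x, scale=2, shave=10):
    h, w = x.size()[-2:]

    # input-space quadrants, each extended by `shave` pixels of overlap
    top, bottom = slice(0, h//2 + shave), slice(h - h//2 - shave, h)
    left, right = slice(0, w//2 + shave), slice(w - w//2 - shave, w)
    chops = [x[..., top, left], x[..., top, right],
             x[..., bottom, left], x[..., bottom, right]]

    # nearest-neighbour upsample stands in for the super-resolution model
    ups = [F.interpolate(c, scale_factor=scale, mode='nearest') for c in chops]

    # output-space slices; the *_r variants index a chop from its far edge
    top, bottom = slice(0, scale * h//2), slice(scale * (h - h//2), scale * h)
    left, right = slice(0, scale * w//2), slice(scale * (w - w//2), scale * w)
    bottom_r = slice(scale * (h//2 - h), None)
    right_r = slice(scale * (w//2 - w), None)

    b, c = x.size()[:-2]
    y = x.new_zeros(b, c, scale * h, scale * w)
    y[..., top, left] = ups[0][..., top, left]
    y[..., top, right] = ups[1][..., top, right_r]
    y[..., bottom, left] = ups[2][..., bottom_r, left]
    y[..., bottom, right] = ups[3][..., bottom_r, right_r]
    return y

x = torch.randn(1, 3, 36, 52)
ref = F.interpolate(x, scale_factor=2, mode='nearest')
assert torch.equal(chop_and_merge(x), ref)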