Skip to content
Snippets Groups Projects
Commit 32cfb6a4 authored by Sanghyun Son's avatar Sanghyun Son
Browse files

need to test a new forward_chop

parent 60c2ba2f
Branches
No related tags found
1 merge request!1Jan 09, 2018 updates
...@@ -76,12 +76,11 @@ class Model(nn.Module): ...@@ -76,12 +76,11 @@ class Model(nn.Module):
for s in save_dirs: torch.save(target.state_dict(), s) for s in save_dirs: torch.save(target.state_dict(), s)
def load(self, apath, pre_train='', resume=-1, cpu=False): def load(self, apath, pre_train='', resume=-1, cpu=False):
load_from = None
kwargs = {}
if cpu: if cpu:
kwargs = {'map_location': lambda storage, loc: storage} kwargs = {'map_location': lambda storage, loc: storage}
else:
kwargs = {}
load_from = None
if resume == -1: if resume == -1:
load_from = torch.load( load_from = torch.load(
os.path.join(apath, 'model_latest.pt'), os.path.join(apath, 'model_latest.pt'),
...@@ -106,61 +105,61 @@ class Model(nn.Module): ...@@ -106,61 +105,61 @@ class Model(nn.Module):
**kwargs **kwargs
) )
if load_from: self.get_model().load_state_dict(load_from, strict=False) if load_from:
self.get_model().load_state_dict(load_from, strict=False)
def forward_chop(self, *args, shave=10, min_size=160000): def forward_chop(self, *args, shave=10, min_size=160000):
if self.input_large: scale = 1 if self.input_large else self.scale[self.idx_scale]
scale = 1
else:
scale = self.scale[self.idx_scale]
n_GPUs = min(self.n_GPUs, 4) n_GPUs = min(self.n_GPUs, 4)
_, _, h, w = args[0].size() # height, width
h_half, w_half = h // 2, w // 2 h, w = args[0].size()[-2:]
h_size, w_size = h_half + shave, w_half + shave
list_x = [[ top = slice(0, h//2 + shave)
a[:, :, 0:h_size, 0:w_size], bottom = slice(h - h//2 - shave, h)
a[:, :, 0:h_size, (w - w_size):w], left = slice(0, w//2 + shave)
a[:, :, (h - h_size):h, 0:w_size], right = slice(w - w//2 - shave, w)
a[:, :, (h - h_size):h, (w - w_size):w] x_chops = [torch.cat([
] for a in args] a[..., top, left],
a[..., top, right],
list_y = [] a[..., bottom, left],
if w_size * h_size < min_size: a[..., bottom, right]
]) for a in args]
y_chops = []
if h * w < 4 * min_size:
for i in range(0, 4, n_GPUs): for i in range(0, 4, n_GPUs):
x = [torch.cat(_x[i:(i + n_GPUs)], dim=0) for _x in list_x] x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
y = self.model(*x) y = self.model(*x)
if not isinstance(y, list): y = [y] if not isinstance(y, list): y = [y]
if not list_y: if not y_chops:
list_y = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
else: else:
for _list_y, _y in zip(list_y, y): for y_chop, _y in zip(y_chops, y):
_list_y.extend(_y.chunk(n_GPUs, dim=0)) y_chop.extend(_y.chunk(n_GPUs, dim=0))
else: else:
for p in zip(*list_x): for p in zip(*x_chops):
y = self.forward_chop(*p, shave=shave, min_size=min_size) y = self.forward_chop(*p, shave=shave, min_size=min_size)
if not isinstance(y, list): y = [y] if not isinstance(y, list): y = [y]
if not list_y: if not y_chops:
list_y = [[_y] for _y in y] y_chops = [[_y] for _y in y]
else: else:
for _list_y, _y in zip(list_y, y): _list_y.append(_y) for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
h, w = scale * h, scale * w top = slice(0, scale * h//2)
h_half, w_half = scale * h_half, scale * w_half bottom = slice(scale * (h - h//2), scale * h)
h_size, w_size = scale * h_size, scale * w_size bottom_r = slice(scale* (h//2 - h), None)
shave *= scale left = slice(0, scale * w//2)
right = slice(scale * (w - w//2), scale * w)
b, c, _, _ = list_y[0][0].size() right_r = slice(scale * w//2, None)
y = [_y[0].new(b, c, h, w) for _y in list_y]
for _list_y, _y in zip(list_y, y): # batch size, number of color channels
_y[:, :, :h_half, :w_half] \ b, c = y_chops[0][0].size()[:-2]
= _list_y[0][:, :, :h_half, :w_half] y = [y_chop[0].new(b, c, scale * h, scale * w) for y_chop in y_chops]
_y[:, :, :h_half, w_half:] \ for y_chop, _y in zip(y_chops, y):
= _list_y[1][:, :, :h_half, (w_size - w + w_half):] _y[..., top, left] = y_chop[0][..., top, left]
_y[:, :, h_half:, :w_half] \ _y[..., top, right] = y_chop[1][..., top, right_r]
= _list_y[2][:, :, (h_size - h + h_half):, :w_half] _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
_y[:, :, h_half:, w_half:] \ _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
= _list_y[3][:, :, (h_size - h + h_half):, (w_size - w + w_half):]
if len(y) == 1: y = y[0] if len(y) == 1: y = y[0]
...@@ -212,4 +211,3 @@ class Model(nn.Module): ...@@ -212,4 +211,3 @@ class Model(nn.Module):
if len(y) == 1: y = y[0] if len(y) == 1: y = y[0]
return y return y
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment