Skip to content
Snippets Groups Projects
Commit 32cfb6a4 authored by Sanghyun Son's avatar Sanghyun Son
Browse files

need to test a new forward_chop

parent 60c2ba2f
No related branches found
No related tags found
1 merge request!1Jan 09, 2018 updates
......@@ -76,12 +76,11 @@ class Model(nn.Module):
for s in save_dirs: torch.save(target.state_dict(), s)
def load(self, apath, pre_train='', resume=-1, cpu=False):
load_from = None
kwargs = {}
if cpu:
kwargs = {'map_location': lambda storage, loc: storage}
else:
kwargs = {}
load_from = None
if resume == -1:
load_from = torch.load(
os.path.join(apath, 'model_latest.pt'),
......@@ -106,61 +105,61 @@ class Model(nn.Module):
**kwargs
)
if load_from: self.get_model().load_state_dict(load_from, strict=False)
if load_from:
self.get_model().load_state_dict(load_from, strict=False)
def forward_chop(self, *args, shave=10, min_size=160000):
if self.input_large:
scale = 1
else:
scale = self.scale[self.idx_scale]
scale = 1 if self.input_large else self.scale[self.idx_scale]
n_GPUs = min(self.n_GPUs, 4)
_, _, h, w = args[0].size()
h_half, w_half = h // 2, w // 2
h_size, w_size = h_half + shave, w_half + shave
list_x = [[
a[:, :, 0:h_size, 0:w_size],
a[:, :, 0:h_size, (w - w_size):w],
a[:, :, (h - h_size):h, 0:w_size],
a[:, :, (h - h_size):h, (w - w_size):w]
] for a in args]
list_y = []
if w_size * h_size < min_size:
# height, width
h, w = args[0].size()[-2:]
top = slice(0, h//2 + shave)
bottom = slice(h - h//2 - shave, h)
left = slice(0, w//2 + shave)
right = slice(w - w//2 - shave, w)
x_chops = [torch.cat([
a[..., top, left],
a[..., top, right],
a[..., bottom, left],
a[..., bottom, right]
]) for a in args]
y_chops = []
if h * w < 4 * min_size:
for i in range(0, 4, n_GPUs):
x = [torch.cat(_x[i:(i + n_GPUs)], dim=0) for _x in list_x]
x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
y = self.model(*x)
if not isinstance(y, list): y = [y]
if not list_y:
list_y = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
if not y_chops:
y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
else:
for _list_y, _y in zip(list_y, y):
_list_y.extend(_y.chunk(n_GPUs, dim=0))
for y_chop, _y in zip(y_chops, y):
y_chop.extend(_y.chunk(n_GPUs, dim=0))
else:
for p in zip(*list_x):
for p in zip(*x_chops):
y = self.forward_chop(*p, shave=shave, min_size=min_size)
if not isinstance(y, list): y = [y]
if not list_y:
list_y = [[_y] for _y in y]
if not y_chops:
y_chops = [[_y] for _y in y]
else:
for _list_y, _y in zip(list_y, y): _list_y.append(_y)
h, w = scale * h, scale * w
h_half, w_half = scale * h_half, scale * w_half
h_size, w_size = scale * h_size, scale * w_size
shave *= scale
b, c, _, _ = list_y[0][0].size()
y = [_y[0].new(b, c, h, w) for _y in list_y]
for _list_y, _y in zip(list_y, y):
_y[:, :, :h_half, :w_half] \
= _list_y[0][:, :, :h_half, :w_half]
_y[:, :, :h_half, w_half:] \
= _list_y[1][:, :, :h_half, (w_size - w + w_half):]
_y[:, :, h_half:, :w_half] \
= _list_y[2][:, :, (h_size - h + h_half):, :w_half]
_y[:, :, h_half:, w_half:] \
= _list_y[3][:, :, (h_size - h + h_half):, (w_size - w + w_half):]
for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
top = slice(0, scale * h//2)
bottom = slice(scale * (h - h//2), scale * h)
bottom_r = slice(scale* (h//2 - h), None)
left = slice(0, scale * w//2)
right = slice(scale * (w - w//2), scale * w)
right_r = slice(scale * w//2, None)
# batch size, number of color channels
b, c = y_chops[0][0].size()[:-2]
y = [y_chop[0].new(b, c, scale * h, scale * w) for y_chop in y_chops]
for y_chop, _y in zip(y_chops, y):
_y[..., top, left] = y_chop[0][..., top, left]
_y[..., top, right] = y_chop[1][..., top, right_r]
_y[..., bottom, left] = y_chop[2][..., bottom_r, left]
_y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
if len(y) == 1: y = y[0]
......@@ -212,4 +211,3 @@ class Model(nn.Module):
if len(y) == 1: y = y[0]
return y
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment