Skip to content
Snippets Groups Projects
Commit 60c2ba2f authored by Sanghyun Son's avatar Sanghyun Son
Browse files

fix network freeze issue

parent b13e310f
No related branches found
No related tags found
1 merge request!1Jan 09, 2018 updates
...@@ -18,7 +18,8 @@ class VGG(nn.Module): ...@@ -18,7 +18,8 @@ class VGG(nn.Module):
vgg_mean = (0.485, 0.456, 0.406) vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
self.vgg.requires_grad = False for p in self.parameters():
p.requires_grad = False
def forward(self, sr, hr): def forward(self, sr, hr):
def _forward(x): def _forward(x):
......
...@@ -18,7 +18,8 @@ class MeanShift(nn.Conv2d): ...@@ -18,7 +18,8 @@ class MeanShift(nn.Conv2d):
std = torch.Tensor(rgb_std) std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
self.requires_grad = False for p in self.parameters():
p.requires_grad = False
class BasicBlock(nn.Sequential): class BasicBlock(nn.Sequential):
def __init__( def __init__(
...@@ -26,8 +27,11 @@ class BasicBlock(nn.Sequential): ...@@ -26,8 +27,11 @@ class BasicBlock(nn.Sequential):
bn=True, act=nn.ReLU(True)): bn=True, act=nn.ReLU(True)):
m = [conv(in_channels, out_channels, kernel_size, bias=bias)] m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
if bn: m.append(nn.BatchNorm2d(out_channels)) if bn:
if act is not None: m.append(act) m.append(nn.BatchNorm2d(out_channels))
if act is not None:
m.append(act)
super(BasicBlock, self).__init__(*m) super(BasicBlock, self).__init__(*m)
class ResBlock(nn.Module): class ResBlock(nn.Module):
...@@ -39,8 +43,10 @@ class ResBlock(nn.Module): ...@@ -39,8 +43,10 @@ class ResBlock(nn.Module):
m = [] m = []
for i in range(2): for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if bn: m.append(nn.BatchNorm2d(n_feats)) if bn:
if i == 0: m.append(act) m.append(nn.BatchNorm2d(n_feats))
if i == 0:
m.append(act)
self.body = nn.Sequential(*m) self.body = nn.Sequential(*m)
self.res_scale = res_scale self.res_scale = res_scale
...@@ -59,8 +65,8 @@ class Upsampler(nn.Sequential): ...@@ -59,8 +65,8 @@ class Upsampler(nn.Sequential):
for _ in range(int(math.log(scale, 2))): for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias)) m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(nn.PixelShuffle(2)) m.append(nn.PixelShuffle(2))
if bn: m.append(nn.BatchNorm2d(n_feats)) if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu': if act == 'relu':
m.append(nn.ReLU(True)) m.append(nn.ReLU(True))
elif act == 'prelu': elif act == 'prelu':
...@@ -69,8 +75,8 @@ class Upsampler(nn.Sequential): ...@@ -69,8 +75,8 @@ class Upsampler(nn.Sequential):
elif scale == 3: elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias)) m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(nn.PixelShuffle(3)) m.append(nn.PixelShuffle(3))
if bn: m.append(nn.BatchNorm2d(n_feats)) if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu': if act == 'relu':
m.append(nn.ReLU(True)) m.append(nn.ReLU(True))
elif act == 'prelu': elif act == 'prelu':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment