From 60c2ba2f3a6459e4521283b70f22e9c34f32cf4a Mon Sep 17 00:00:00 2001 From: Sanghyun Son <thstkdgus35@snu.ac.kr> Date: Mon, 5 Nov 2018 12:05:14 +0900 Subject: [PATCH] fix network freeze issue --- src/loss/vgg.py | 3 ++- src/model/common.py | 24 +++++++++++++++--------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/loss/vgg.py b/src/loss/vgg.py index 335716d..42ab9d0 100644 --- a/src/loss/vgg.py +++ b/src/loss/vgg.py @@ -18,7 +18,8 @@ class VGG(nn.Module): vgg_mean = (0.485, 0.456, 0.406) vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) - self.vgg.requires_grad = False + for p in self.parameters(): + p.requires_grad = False def forward(self, sr, hr): def _forward(x): diff --git a/src/model/common.py b/src/model/common.py index 74ffa37..aeee335 100644 --- a/src/model/common.py +++ b/src/model/common.py @@ -18,7 +18,8 @@ class MeanShift(nn.Conv2d): std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std - self.requires_grad = False + for p in self.parameters(): + p.requires_grad = False class BasicBlock(nn.Sequential): def __init__( @@ -26,8 +27,11 @@ class BasicBlock(nn.Sequential): bn=True, act=nn.ReLU(True)): m = [conv(in_channels, out_channels, kernel_size, bias=bias)] - if bn: m.append(nn.BatchNorm2d(out_channels)) - if act is not None: m.append(act) + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): @@ -39,8 +43,10 @@ class ResBlock(nn.Module): m = [] for i in range(2): m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) - if bn: m.append(nn.BatchNorm2d(n_feats)) - if i == 0: m.append(act) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale @@ -59,8 +65,8 @@ class Upsampler(nn.Sequential): for _ in range(int(math.log(scale, 2))): m.append(conv(n_feats, 4 * n_feats, 3, bias)) m.append(nn.PixelShuffle(2)) - if bn: m.append(nn.BatchNorm2d(n_feats)) - + if bn: + m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu': @@ -69,8 +75,8 @@ class Upsampler(nn.Sequential): elif scale == 3: m.append(conv(n_feats, 9 * n_feats, 3, bias)) m.append(nn.PixelShuffle(3)) - if bn: m.append(nn.BatchNorm2d(n_feats)) - + if bn: + m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu': -- GitLab