From 60c2ba2f3a6459e4521283b70f22e9c34f32cf4a Mon Sep 17 00:00:00 2001
From: Sanghyun Son <thstkdgus35@snu.ac.kr>
Date: Mon, 5 Nov 2018 12:05:14 +0900
Subject: [PATCH] fix network freeze issue

---
 src/loss/vgg.py     |  3 ++-
 src/model/common.py | 24 +++++++++++++++---------
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/src/loss/vgg.py b/src/loss/vgg.py
index 335716d..42ab9d0 100644
--- a/src/loss/vgg.py
+++ b/src/loss/vgg.py
@@ -18,7 +18,8 @@ class VGG(nn.Module):
         vgg_mean = (0.485, 0.456, 0.406)
         vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
         self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
-        self.vgg.requires_grad = False
+        for p in self.parameters():
+            p.requires_grad = False
 
     def forward(self, sr, hr):
         def _forward(x):
diff --git a/src/model/common.py b/src/model/common.py
index 74ffa37..aeee335 100644
--- a/src/model/common.py
+++ b/src/model/common.py
@@ -18,7 +18,8 @@ class MeanShift(nn.Conv2d):
         std = torch.Tensor(rgb_std)
         self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
         self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
-        self.requires_grad = False
+        for p in self.parameters():
+            p.requires_grad = False
 
 class BasicBlock(nn.Sequential):
     def __init__(
@@ -26,8 +27,11 @@ class BasicBlock(nn.Sequential):
         bn=True, act=nn.ReLU(True)):
 
         m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
-        if bn: m.append(nn.BatchNorm2d(out_channels))
-        if act is not None: m.append(act)
+        if bn:
+            m.append(nn.BatchNorm2d(out_channels))
+        if act is not None:
+            m.append(act)
+
         super(BasicBlock, self).__init__(*m)
 
 class ResBlock(nn.Module):
@@ -39,8 +43,10 @@ class ResBlock(nn.Module):
         m = []
         for i in range(2):
             m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
-            if bn: m.append(nn.BatchNorm2d(n_feats))
-            if i == 0: m.append(act)
+            if bn:
+                m.append(nn.BatchNorm2d(n_feats))
+            if i == 0:
+                m.append(act)
 
         self.body = nn.Sequential(*m)
         self.res_scale = res_scale
@@ -59,8 +65,8 @@ class Upsampler(nn.Sequential):
             for _ in range(int(math.log(scale, 2))):
                 m.append(conv(n_feats, 4 * n_feats, 3, bias))
                 m.append(nn.PixelShuffle(2))
-                if bn: m.append(nn.BatchNorm2d(n_feats))
-
+                if bn:
+                    m.append(nn.BatchNorm2d(n_feats))
                 if act == 'relu':
                     m.append(nn.ReLU(True))
                 elif act == 'prelu':
@@ -69,8 +75,8 @@ class Upsampler(nn.Sequential):
         elif scale == 3:
             m.append(conv(n_feats, 9 * n_feats, 3, bias))
             m.append(nn.PixelShuffle(3))
-            if bn: m.append(nn.BatchNorm2d(n_feats))
-
+            if bn:
+                m.append(nn.BatchNorm2d(n_feats))
             if act == 'relu':
                 m.append(nn.ReLU(True))
             elif act == 'prelu':
-- 
GitLab