diff --git a/src/loss/vgg.py b/src/loss/vgg.py index 335716d56b6bbd0876555f9e297480353288a0d9..42ab9d0a9914076e32b6400db2feb03c84574cbc 100644 --- a/src/loss/vgg.py +++ b/src/loss/vgg.py @@ -18,7 +18,8 @@ class VGG(nn.Module): vgg_mean = (0.485, 0.456, 0.406) vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) - self.vgg.requires_grad = False + for p in self.parameters(): + p.requires_grad = False def forward(self, sr, hr): def _forward(x): diff --git a/src/model/common.py b/src/model/common.py index 74ffa371a680e80bcea79853f4d948c433c72721..aeee33512680a5c1437f9ad68be18bc2436fca61 100644 --- a/src/model/common.py +++ b/src/model/common.py @@ -18,7 +18,8 @@ class MeanShift(nn.Conv2d): std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std - self.requires_grad = False + for p in self.parameters(): + p.requires_grad = False class BasicBlock(nn.Sequential): def __init__( @@ -26,8 +27,11 @@ class BasicBlock(nn.Sequential): bn=True, act=nn.ReLU(True)): m = [conv(in_channels, out_channels, kernel_size, bias=bias)] - if bn: m.append(nn.BatchNorm2d(out_channels)) - if act is not None: m.append(act) + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): @@ -39,8 +43,10 @@ class ResBlock(nn.Module): m = [] for i in range(2): m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) - if bn: m.append(nn.BatchNorm2d(n_feats)) - if i == 0: m.append(act) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale @@ -59,8 +65,8 @@ class Upsampler(nn.Sequential): for _ in range(int(math.log(scale, 2))): m.append(conv(n_feats, 4 * n_feats, 3, bias)) m.append(nn.PixelShuffle(2)) - if bn: m.append(nn.BatchNorm2d(n_feats)) - + if bn: + m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu': @@ -69,8 +75,8 @@ class Upsampler(nn.Sequential): elif scale == 3: m.append(conv(n_feats, 9 * n_feats, 3, bias)) m.append(nn.PixelShuffle(3)) - if bn: m.append(nn.BatchNorm2d(n_feats)) - + if bn: + m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu':