From ea971d90bbf48ce47a00936b6a1356fd76556026 Mon Sep 17 00:00:00 2001
From: Sanghyun Son <thstkdgus35@snu.ac.kr>
Date: Fri, 5 Oct 2018 14:06:28 +0900
Subject: [PATCH] bug fix in gan

---
 src/demo.sh               |  4 ++--
 src/loss/discriminator.py | 33 ++++++++++++++++++++-------------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/src/demo.sh b/src/demo.sh
index cab75ea..4656546 100644
--- a/src/demo.sh
+++ b/src/demo.sh
@@ -1,6 +1,6 @@
-# EDSR baseline model (x2)
+# EDSR baseline model (x2) + JPEG augmentation
 #python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset
-python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
+#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
 
 # EDSR baseline model (x3) - from EDSR baseline model (x2)
 #python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
diff --git a/src/loss/discriminator.py b/src/loss/discriminator.py
index feb21ed..53fff1a 100644
--- a/src/loss/discriminator.py
+++ b/src/loss/discriminator.py
@@ -6,16 +6,25 @@ class Discriminator(nn.Module):
     def __init__(self, args, gan_type='GAN'):
         super(Discriminator, self).__init__()
 
-        in_channels = 3
+        in_channels = args.n_colors
         out_channels = 64
         depth = 7
-        #bn = not gan_type == 'WGAN_GP'
-        bn = True
-        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-        m_features = [
-            common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act)
-        ]
+        def _block(_in_channels, _out_channels, stride=1):
+            return nn.Sequential(
+                nn.Conv2d(
+                    _in_channels,
+                    _out_channels,
+                    3,
+                    padding=1,
+                    stride=stride,
+                    bias=False
+                ),
+                nn.BatchNorm2d(_out_channels),
+                nn.LeakyReLU(negative_slope=0.2, inplace=True)
+            )
+
+        m_features = [_block(in_channels, out_channels)]
         for i in range(depth):
             in_channels = out_channels
             if i % 2 == 1:
@@ -23,18 +32,16 @@ class Discriminator(nn.Module):
                 out_channels *= 2
             else:
                 stride = 2
-            m_features.append(common.BasicBlock(
-                in_channels, out_channels, 3, stride=stride, bn=bn, act=act
-            ))
-
-        self.features = nn.Sequential(*m_features)
+            m_features.append(_block(in_channels, out_channels, stride=stride))
 
         patch_size = args.patch_size // (2**((depth + 1) // 2))
         m_classifier = [
             nn.Linear(out_channels * patch_size**2, 1024),
-            act,
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
             nn.Linear(1024, 1)
         ]
+
+        self.features = nn.Sequential(*m_features)
         self.classifier = nn.Sequential(*m_classifier)
 
     def forward(self, x):
-- 
GitLab