From ea971d90bbf48ce47a00936b6a1356fd76556026 Mon Sep 17 00:00:00 2001
From: Sanghyun Son <thstkdgus35@snu.ac.kr>
Date: Fri, 5 Oct 2018 14:06:28 +0900
Subject: [PATCH] bug fix in gan

---
 src/demo.sh               |  4 ++--
 src/loss/discriminator.py | 33 ++++++++++++++++++++-------------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/src/demo.sh b/src/demo.sh
index cab75ea..4656546 100644
--- a/src/demo.sh
+++ b/src/demo.sh
@@ -1,6 +1,6 @@
-# EDSR baseline model (x2)
+# EDSR baseline model (x2) + JPEG augmentation
 #python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset
-python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
+#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
 
 # EDSR baseline model (x3) - from EDSR baseline model (x2)
 #python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
diff --git a/src/loss/discriminator.py b/src/loss/discriminator.py
index feb21ed..53fff1a 100644
--- a/src/loss/discriminator.py
+++ b/src/loss/discriminator.py
@@ -6,16 +6,25 @@ class Discriminator(nn.Module):
     def __init__(self, args, gan_type='GAN'):
         super(Discriminator, self).__init__()
 
-        in_channels = 3
+        in_channels = args.n_colors
         out_channels = 64
         depth = 7
-        #bn = not gan_type == 'WGAN_GP'
-        bn = True
-        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-        m_features = [
-            common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act)
-        ]
+        def _block(_in_channels, _out_channels, stride=1):
+            return nn.Sequential(
+                nn.Conv2d(
+                    _in_channels,
+                    _out_channels,
+                    3,
+                    padding=1,
+                    stride=stride,
+                    bias=False
+                ),
+                nn.BatchNorm2d(_out_channels),
+                nn.LeakyReLU(negative_slope=0.2, inplace=True)
+            )
+
+        m_features = [_block(in_channels, out_channels)]
         for i in range(depth):
             in_channels = out_channels
             if i % 2 == 1:
@@ -23,18 +32,16 @@ class Discriminator(nn.Module):
                 out_channels *= 2
             else:
                 stride = 2
-            m_features.append(common.BasicBlock(
-                in_channels, out_channels, 3, stride=stride, bn=bn, act=act
-            ))
-
-        self.features = nn.Sequential(*m_features)
+            m_features.append(_block(in_channels, out_channels, stride=stride))
 
         patch_size = args.patch_size // (2**((depth + 1) // 2))
         m_classifier = [
             nn.Linear(out_channels * patch_size**2, 1024),
-            act,
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
             nn.Linear(1024, 1)
         ]
+
+        self.features = nn.Sequential(*m_features)
         self.classifier = nn.Sequential(*m_classifier)
 
     def forward(self, x):
--
GitLab
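
Note: a minimal smoke test for the patched discriminator, offered as a sketch rather than part of the commit. It assumes you run it from the repository's src/ directory so that loss.discriminator is importable, and it fakes only the two args fields the constructor reads in this diff (n_colors and patch_size); the file name sanity_check_discriminator.py is hypothetical.

    # sanity_check_discriminator.py -- hypothetical helper, run from src/
    from argparse import Namespace

    import torch

    from loss.discriminator import Discriminator

    # The patched __init__ only reads args.n_colors and args.patch_size.
    args = Namespace(n_colors=3, patch_size=96)
    netD = Discriminator(args)

    # With depth = 7 the feature stack applies four stride-2 convolutions,
    # so a 96x96 patch is reduced to 6x6 maps before the linear classifier.
    fake = torch.randn(1, args.n_colors, args.patch_size, args.patch_size)
    out = netD(fake)
    print(out.shape)  # expected: torch.Size([1, 1])

Since the patch replaces common.BasicBlock with a plain Conv2d + BatchNorm2d + LeakyReLU stack and drops the hard-coded in_channels = 3, this forward pass exercises exactly the layers the diff introduces.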