Skip to content
Snippets Groups Projects
Commit ea971d90 authored by Sanghyun Son's avatar Sanghyun Son
Browse files

bug fix in gan

parent a91ccea9
No related branches found
No related tags found
1 merge request!1Jan 09, 2018 updates
# EDSR baseline model (x2)
# EDSR baseline model (x2) + JPEG augmentation
#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset
python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
# EDSR baseline model (x3) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
......
......@@ -6,16 +6,25 @@ class Discriminator(nn.Module):
def __init__(self, args, gan_type='GAN'):
super(Discriminator, self).__init__()
in_channels = 3
in_channels = args.n_colors
out_channels = 64
depth = 7
#bn = not gan_type == 'WGAN_GP'
bn = True
act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
m_features = [
common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act)
]
def _block(_in_channels, _out_channels, stride=1):
return nn.Sequential(
nn.Conv2d(
_in_channels,
_out_channels,
3,
padding=1,
stride=stride,
bias=False
),
nn.BatchNorm2d(_out_channels),
nn.LeakyReLU(negative_slope=0.2, inplace=True)
)
m_features = [_block(in_channels, out_channels)]
for i in range(depth):
in_channels = out_channels
if i % 2 == 1:
......@@ -23,18 +32,16 @@ class Discriminator(nn.Module):
out_channels *= 2
else:
stride = 2
m_features.append(common.BasicBlock(
in_channels, out_channels, 3, stride=stride, bn=bn, act=act
))
self.features = nn.Sequential(*m_features)
m_features.append(_block(in_channels, out_channels, stride=stride))
patch_size = args.patch_size // (2**((depth + 1) // 2))
m_classifier = [
nn.Linear(out_channels * patch_size**2, 1024),
act,
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(1024, 1)
]
self.features = nn.Sequential(*m_features)
self.classifier = nn.Sequential(*m_classifier)
def forward(self, x):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment