Skip to content
Snippets Groups Projects
Select Git revision
  • ea971d90bbf48ce47a00936b6a1356fd76556026
  • main default
2 results

discriminator.py

Blame
  • discriminator.py 1.53 KiB
    from model import common
    
    import torch.nn as nn
    
    class Discriminator(nn.Module):
        """VGG-style convolutional discriminator with a fully connected head.

        The feature extractor is ``depth + 1`` conv blocks, each a bias-free
        3x3 ``Conv2d`` (padding 1) -> ``BatchNorm2d`` -> ``LeakyReLU(0.2)``.
        After the first (stride-1, ``n_colors -> 64``) block, blocks alternate
        stride 2 and stride 1, and the channel count doubles on each of those
        stride-1 blocks (64 -> 128 -> 256 -> 512 for ``depth = 7``). The
        flattened feature map goes through ``Linear -> LeakyReLU -> Linear``
        and yields one raw score per sample; no sigmoid is applied here.

        Args:
            args: configuration object. Reads ``args.n_colors`` (channels of
                the input images) and ``args.patch_size`` (side length of the
                square input images; ``forward`` inputs must have this spatial
                size so the flattened features match the classifier).
            gan_type: accepted for interface compatibility; not read anywhere
                in this class.
        """

        def __init__(self, args, gan_type='GAN'):
            super(Discriminator, self).__init__()

            in_channels = args.n_colors
            out_channels = 64
            depth = 7

            def _block(_in_channels, _out_channels, stride=1):
                # Conv bias is redundant: the following BatchNorm2d re-centres
                # each channel and applies its own learned shift.
                return nn.Sequential(
                    nn.Conv2d(
                        _in_channels,
                        _out_channels,
                        3,
                        padding=1,
                        stride=stride,
                        bias=False
                    ),
                    nn.BatchNorm2d(_out_channels),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)
                )

            # Side length of the feature map, tracked block by block. A 3x3
            # conv with padding 1 and stride s maps H -> (H - 1) // s + 1, i.e.
            # ceil(H / 2) for stride 2. Computing it exactly (instead of
            # patch_size // 2**n_downsamples) keeps the classifier's input size
            # correct when patch_size is not a multiple of 16.
            feature_size = args.patch_size
            m_features = [_block(in_channels, out_channels)]  # stride 1: size unchanged
            for i in range(depth):
                in_channels = out_channels
                if i % 2 == 1:
                    stride = 1
                    out_channels *= 2
                else:
                    stride = 2
                m_features.append(_block(in_channels, out_channels, stride=stride))
                feature_size = (feature_size - 1) // stride + 1

            m_classifier = [
                nn.Linear(out_channels * feature_size**2, 1024),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(1024, 1)
            ]

            self.features = nn.Sequential(*m_features)
            self.classifier = nn.Sequential(*m_classifier)

        def forward(self, x):
            """Score a batch of images.

            Args:
                x: image batch, assumed (N, n_colors, patch_size, patch_size)
                    to match the sizes fixed in ``__init__``.

            Returns:
                Tensor of shape (N, 1) holding un-normalized scores.
            """
            features = self.features(x)
            # flatten(1) instead of view(): identical for contiguous tensors,
            # but also works for non-contiguous (e.g. channels_last) features.
            output = self.classifier(features.flatten(1))

            return output