# adversarial.py
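"""Adversarial loss for GAN training.

Wraps a discriminator and its optimizer and supports several objectives
selected by `gan_type`: 'GAN' (standard BCE), 'WGAN' (weight clipping),
'WGAN_GP' (gradient penalty), and 'RGAN' (relativistic GAN, as in ESRGAN).
"""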
import utility
from types import SimpleNamespace

from model import common
from loss import discriminator

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Adversarial(nn.Module):
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        if gan_type == 'WGAN_GP':
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
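        """Run gan_k discriminator updates on (fake, real), then return the
        adversarial loss for the generator.

        The discriminator loss averaged over the gan_k steps is stored in
        self.loss for logging.
        """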
        # updating discriminator...
        self.loss = 0
        fake_detach = fake.detach()  # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type == 'GAN':
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
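                # WGAN-GP (https://arxiv.org/abs/1704.00028): penalize the
                # critic's gradient norm at random interpolates between real
                # and fake samples: 10 * E[(||grad D(x_hat)||_2 - 1)^2]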
                if self.gan_type.find('GP') >= 0:
                    # draw one interpolation coefficient per sample
                    epsilon = torch.rand(real.size(0), 1, 1, 1, device=real.device)
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
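            # Relativistic GAN: compare each score with the mean score of the
            # opposite class instead of an absolute real/fake label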
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                retain_graph = True  # d_real is reused in the generator loss below

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()
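            # plain WGAN enforces the Lipschitz constraint by weight clipping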
            if self.gan_type == 'WGAN':
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # updating generator...
        d_fake_bp = self.dis(fake)  # for backpropagation, use fake as it is
        if self.gan_type == 'GAN':
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
            loss_g = self.bce(better_fake, better_real)

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()
        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
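        """Standard GAN BCE: real logits against label 1 and fake logits
        against label 0, summed."""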
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss

# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# or
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
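
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes an `args` namespace like the one this repo's option parser builds;
# __init__ itself only reads gan_k / weight_decay / decay / gamma, and the
# Discriminator reads further fields of its own.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(gan_k=1, weight_decay=0, decay='200', gamma=0.5)
#   adv = Adversarial(args, gan_type='WGAN_GP')
#   loss_g = adv(sr, hr)  # sr: generator output, hr: ground truth
#   loss_g.backward()     # gradients flow to the generator only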