diff --git a/src/loss/__init__.py b/src/loss/__init__.py index 6d7c21e2f9dbf94a20fb6c74ede439711c44a15c..8171a5aacac117208fba1ccfe43c69885de7728c 100644 --- a/src/loss/__init__.py +++ b/src/loss/__init__.py @@ -1,143 +1,143 @@ -import os -from importlib import import_module - -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt - -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -class Loss(nn.modules.loss._Loss): - def __init__(self, args, ckp): - super(Loss, self).__init__() - print('Preparing loss function:') - - self.n_GPUs = args.n_GPUs - self.loss = [] - self.loss_module = nn.ModuleList() - for loss in args.loss.split('+'): - weight, loss_type = loss.split('*') - if loss_type == 'MSE': - loss_function = nn.MSELoss() - elif loss_type == 'L1': - loss_function = nn.L1Loss() - elif loss_type.find('VGG') >= 0: - module = import_module('loss.vgg') - loss_function = getattr(module, 'VGG')( - loss_type[3:], - rgb_range=args.rgb_range - ) - elif loss_type.find('GAN') >= 0: - module = import_module('loss.adversarial') - loss_function = getattr(module, 'Adversarial')( - args, - loss_type - ) - - self.loss.append({ - 'type': loss_type, - 'weight': float(weight), - 'function': loss_function} - ) - if loss_type.find('GAN') >= 0: - self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) - - if len(self.loss) > 1: - self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) - - for l in self.loss: - if l['function'] is not None: - print('{:.3f} * {}'.format(l['weight'], l['type'])) - self.loss_module.append(l['function']) - - self.log = torch.Tensor() - - device = torch.device('cpu' if args.cpu else 'cuda') - self.loss_module.to(device) - if args.precision == 'half': self.loss_module.half() - if not args.cpu and args.n_GPUs > 1: - self.loss_module = nn.DataParallel( - self.loss_module, range(args.n_GPUs) - ) - - if args.load != '': self.load(ckp.dir, cpu=args.cpu) - - def forward(self, sr, hr): - losses = [] - for i, l in enumerate(self.loss): - if l['function'] is not None: - loss = l['function'](sr, hr) - effective_loss = l['weight'] * loss - losses.append(effective_loss) - self.log[-1, i] += effective_loss.item() - elif l['type'] == 'DIS': - self.log[-1, i] += self.loss[i - 1]['function'].loss - - loss_sum = sum(losses) - if len(self.loss) > 1: - self.log[-1, -1] += loss_sum.item() - - return loss_sum - - def step(self): - for l in self.get_loss_module(): - if hasattr(l, 'scheduler'): - l.scheduler.step() - - def start_log(self): - self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) - - def end_log(self, n_batches): - self.log[-1].div_(n_batches) - - def display_loss(self, batch): - n_samples = batch + 1 - log = [] - for l, c in zip(self.loss, self.log[-1]): - log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) - - return ''.join(log) - - def plot_loss(self, apath, epoch): - axis = np.linspace(1, epoch, epoch) - for i, l in enumerate(self.loss): - label = '{} Loss'.format(l['type']) - fig = plt.figure() - plt.title(label) - plt.plot(axis, self.log[:, i].numpy(), label=label) - plt.legend() - plt.xlabel('Epochs') - plt.ylabel('Loss') - plt.grid(True) - plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) - plt.close(fig) - - def get_loss_module(self): - if self.n_GPUs == 1: - return self.loss_module - else: - return self.loss_module.module - - def save(self, apath): - torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) - torch.save(self.log, os.path.join(apath, 'loss_log.pt')) - - def load(self, apath, cpu=False): - if cpu: - kwargs = {'map_location': lambda storage, loc: storage} - else: - kwargs = {} - - self.load_state_dict(torch.load( - os.path.join(apath, 'loss.pt'), - **kwargs - )) - self.log = torch.load(os.path.join(apath, 'loss_log.pt')) - for l in self.loss_module: - if hasattr(l, 'scheduler'): - for _ in range(len(self.log)): l.scheduler.step() - +import os +from importlib import import_module + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Loss(nn.modules.loss._Loss): + def __init__(self, args, ckp): + super(Loss, self).__init__() + print('Preparing loss function:') + + self.n_GPUs = args.n_GPUs + self.loss = [] + self.loss_module = nn.ModuleList() + for loss in args.loss.split('+'): + weight, loss_type = loss.split('*') + if loss_type == 'MSE': + loss_function = nn.MSELoss() + elif loss_type == 'L1': + loss_function = nn.L1Loss() + elif loss_type.find('VGG') >= 0: + module = import_module('loss.vgg') + loss_function = getattr(module, 'VGG')( + loss_type[3:], + rgb_range=args.rgb_range + ) + elif loss_type.find('GAN') >= 0: + module = import_module('loss.adversarial') + loss_function = getattr(module, 'Adversarial')( + args, + loss_type + ) + + self.loss.append({ + 'type': loss_type, + 'weight': float(weight), + 'function': loss_function} + ) + if loss_type.find('GAN') >= 0: + self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) + + if len(self.loss) > 1: + self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) + + for l in self.loss: + if l['function'] is not None: + print('{:.3f} * {}'.format(l['weight'], l['type'])) + self.loss_module.append(l['function']) + + self.log = torch.Tensor() + + device = torch.device('cpu' if args.cpu else 'cuda') + self.loss_module.to(device) + if args.precision == 'half': self.loss_module.half() + if not args.cpu and args.n_GPUs > 1: + self.loss_module = nn.DataParallel( + self.loss_module, range(args.n_GPUs) + ) + + if args.load != '': self.load(ckp.dir, cpu=args.cpu) + + def forward(self, sr, hr): + losses = [] + for i, l in enumerate(self.loss): + if l['function'] is not None: + loss = l['function'](sr, hr) + effective_loss = l['weight'] * loss + losses.append(effective_loss) + self.log[-1, i] += effective_loss.item() + elif l['type'] == 'DIS': + self.log[-1, i] += self.loss[i - 1]['function'].loss + + loss_sum = sum(losses) + if len(self.loss) > 1: + self.log[-1, -1] += loss_sum.item() + + return loss_sum + + def step(self): + for l in self.get_loss_module(): + if hasattr(l, 'scheduler'): + l.scheduler.step() + + def start_log(self): + self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) + + def end_log(self, n_batches): + self.log[-1].div_(n_batches) + + def display_loss(self, batch): + n_samples = batch + 1 + log = [] + for l, c in zip(self.loss, self.log[-1]): + log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) + + return ''.join(log) + + def plot_loss(self, apath, epoch): + axis = np.linspace(1, epoch, epoch) + for i, l in enumerate(self.loss): + label = '{} Loss'.format(l['type']) + fig = plt.figure() + plt.title(label) + plt.plot(axis, self.log[:, i].numpy(), label=label) + plt.legend() + plt.xlabel('Epochs') + plt.ylabel('Loss') + plt.grid(True) + plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) + plt.close(fig) + + def get_loss_module(self): + if self.n_GPUs == 1: + return self.loss_module + else: + return self.loss_module.module + + def save(self, apath): + torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) + torch.save(self.log, os.path.join(apath, 'loss_log.pt')) + + def load(self, apath, cpu=False): + if cpu: + kwargs = {'map_location': lambda storage, loc: storage} + else: + kwargs = {} + + self.load_state_dict(torch.load( + os.path.join(apath, 'loss.pt'), + **kwargs + )) + self.log = torch.load(os.path.join(apath, 'loss_log.pt')) + for l in self.get_loss_module(): + if hasattr(l, 'scheduler'): + for _ in range(len(self.log)): l.scheduler.step() +