Skip to content
Snippets Groups Projects
Select Git revision
  • 60f90ea2b155503bb2124ce1a4a1d314130b43c4
  • main default protected
2 results

greetings.js

Blame
  • utility.py 6.43 KiB
    import datetime
    import math
    import os
    import shutil
    import time
    from multiprocessing import Process
    from multiprocessing import Queue

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    import numpy as np
    import imageio

    import torch
    import torch.optim as optim
    import torch.optim.lr_scheduler as lrs
    
    class timer():
        """Wall-clock stopwatch with an accumulator for summing intervals.

        ``tic``/``toc`` measure one interval from the start mark ``t0``;
        ``hold`` adds the current interval to ``acc`` and ``release`` hands
        back the accumulated total while zeroing it.
        """

        def __init__(self):
            self.reset()
            self.tic()

        def tic(self):
            """Set the start mark ``t0`` to the current time."""
            self.t0 = time.time()

        def toc(self, restart=False):
            """Return seconds elapsed since ``t0``; with ``restart`` also re-mark ``t0``."""
            elapsed = time.time() - self.t0
            if restart:
                self.tic()
            return elapsed

        def hold(self):
            """Add the time elapsed since ``t0`` to ``acc`` (``t0`` is left unchanged)."""
            self.acc = self.acc + self.toc()

        def release(self):
            """Return the accumulated time and zero the accumulator."""
            total, self.acc = self.acc, 0
            return total

        def reset(self):
            """Zero the accumulator without touching ``t0``."""
            self.acc = 0
    
    class checkpoint():
        """Experiment bookkeeping for one training/testing run.

        Owns the run directory ``../experiment/<name>`` (relative to the current
        working directory). It holds the text log, the config dump, the PSNR
        history, model/optimizer checkpoints, and result images written by
        background worker processes.
        """

        def __init__(self, args):
            """Resolve, optionally reset, and create the run directory.

            Args:
                args: option namespace. Reads ``load``, ``save`` and ``reset``.
                    The namespace is mutated:
                    - ``args.save`` becomes a timestamp when both ``load`` and
                      ``save`` are ``'.'``.
                    - ``args.load`` becomes ``'.'`` when the resume directory
                      is missing, or on reset.
                    Every attribute is written to ``config.txt``.
            """
            self.args = args
            self.ok = True
            # PSNR history: one row per epoch, one column per entry of
            # args.scale (see plot_psnr).
            self.log = torch.Tensor()
            now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

            if args.load == '.':
                if args.save == '.':
                    args.save = now
                self.dir = '../experiment/' + args.save
            else:
                self.dir = '../experiment/' + args.load
                if not os.path.exists(self.dir):
                    # Nothing to resume from; start fresh in that directory.
                    args.load = '.'
                else:
                    # NOTE(review): torch.load unpickles the file -- only
                    # resume from experiment directories you trust.
                    self.log = torch.load(self.dir + '/psnr_log.pt')
                    print('Continue from epoch {}...'.format(len(self.log)))

            if args.reset:
                # No shell involved (unlike `os.system('rm -rf ' + dir)`), so a
                # run name containing spaces or shell metacharacters cannot
                # delete or execute anything else. ignore_errors keeps the
                # `rm -rf` leniency when the directory does not exist.
                shutil.rmtree(self.dir, ignore_errors=True)
                args.load = '.'
                # Any history loaded above belonged to the run just deleted.
                self.log = torch.Tensor()

            os.makedirs(self.dir, exist_ok=True)
            os.makedirs(self.get_path('model'), exist_ok=True)
            os.makedirs(self.get_path('results'), exist_ok=True)

            open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
            # Stays open for the whole run; closed by done().
            self.log_file = open(self.get_path('log.txt'), open_type)
            with open(self.get_path('config.txt'), open_type) as f:
                f.write(now + '\n\n')
                for arg in vars(args):
                    f.write('{}: {}\n'.format(arg, getattr(args, arg)))
                f.write('\n')

            # Number of image-writer processes started by begin_background().
            self.n_processes = 8

        def get_path(self, *subdir):
            """Join the ``subdir`` components onto the run directory."""
            return os.path.join(self.dir, *subdir)

        def save(self, trainer, epoch, is_best=False):
            """Checkpoint everything for ``epoch``.

            This method:
            - delegates to ``trainer.model.save`` (into ``model/``);
            - calls ``trainer.loss.save`` and ``trainer.loss.plot_loss``;
            - redraws the PSNR plot;
            - writes ``psnr_log.pt`` and the optimizer state to
              ``optimizer.pt``.
            """
            trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
            trainer.loss.save(self.dir)
            trainer.loss.plot_loss(self.dir, epoch)

            self.plot_psnr(epoch)
            torch.save(self.log, self.get_path('psnr_log.pt'))
            torch.save(
                trainer.optimizer.state_dict(),
                self.get_path('optimizer.pt')
            )

        def add_log(self, log):
            """Append the rows of ``log`` to the PSNR history (concatenated along dim 0)."""
            self.log = torch.cat([self.log, log])

        def write_log(self, log, refresh=False):
            """Print ``log`` and append it as one line to ``log.txt``.

            With ``refresh=True`` the file is closed and reopened in append
            mode, which flushes everything written so far to disk.
            """
            print(log)
            self.log_file.write(log + '\n')
            if refresh:
                self.log_file.close()
                self.log_file = open(self.get_path('log.txt'), 'a')

        def done(self):
            """Close ``log.txt``."""
            self.log_file.close()

        def plot_psnr(self, epoch):
            """Plot the PSNR history, one curve per scale, to ``test_<data_test>.pdf``.

            Assumes ``self.log`` has exactly ``epoch`` rows and a column for
            each entry of ``args.scale``.
            """
            axis = np.linspace(1, epoch, epoch)
            label = 'SR on {}'.format(self.args.data_test)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test)))
            plt.close(fig)

        @staticmethod
        def _bg_target(queue):
            """Worker loop: write queued ``(filename, tensor)`` images until a ``None`` filename arrives."""
            while True:
                # Blocking get: an idle worker sleeps instead of spinning on
                # queue.empty() at full CPU.
                filename, tensor = queue.get()
                if filename is None:
                    break
                imageio.imwrite(filename, tensor.numpy())

        def begin_background(self):
            """Start ``n_processes`` worker processes that write images from ``self.queue``."""
            self.queue = Queue()
            # A staticmethod (unlike a closure) can be pickled by qualified
            # name, so this also works with the 'spawn' start method.
            self.process = [
                Process(target=self._bg_target, args=(self.queue,))
                for _ in range(self.n_processes)
            ]
            for p in self.process:
                p.start()

        def end_background(self):
            """Stop the workers after every previously queued image is written.

            One sentinel is queued per worker behind all pending images. Each
            worker finishes its current write before taking the next item, so
            joining the workers is enough to wait for the queue to drain.
            """
            for _ in range(self.n_processes):
                self.queue.put((None, None))
            for p in self.process:
                p.join()

        def save_results(self, filename, save_list, scale):
            """Queue images for background writing as ``results/<filename>_x<scale>_<SR|LR|HR>.png``.

            Items of ``save_list`` are paired in order with ('SR', 'LR', 'HR');
            extra items are ignored. The first element of each item is taken
            and permuted (1, 2, 0), so each item must be at least 4-D with
            channel-first images. Values are scaled by ``255 / rgb_range`` and
            cast with ``.byte()`` without clamping, so inputs are presumably
            already quantized to [0, rgb_range]. Requires begin_background()
            to have been called.
            """
            filename = self.get_path('results', '{}_x{}_'.format(filename, scale))
            postfix = ('SR', 'LR', 'HR')
            for v, p in zip(save_list, postfix):
                normalized = v[0].mul(255 / self.args.rgb_range)
                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
    
    def quantize(img, rgb_range):
        """Snap ``img`` onto the 256-level 8-bit grid while staying in ``rgb_range`` units.

        Values are scaled to [0, 255], clamped, rounded to the nearest integer
        level, then scaled back, so the result lies in [0, rgb_range].
        """
        levels_per_unit = 255 / rgb_range
        pixels = torch.clamp(img * levels_per_unit, min=0, max=255)
        return torch.round(pixels) / levels_per_unit
    
    def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
        """Return the PSNR (in dB) between ``sr`` and ``hr``.

        The difference is divided by ``rgb_range`` so the peak signal is 1.
        Before averaging, a border of ``shave`` pixels is cropped from both
        ends of the last two dimensions: ``scale`` pixels in benchmark mode,
        ``scale + 6`` otherwise.

        Args:
            sr, hr: 4-D tensors (the slicing below fixes rank 4), channel
                dimension at index 1; values presumably in [0, rgb_range].
            scale: integer scale factor, used as a slice bound for the crop.
            rgb_range: maximum pixel value of the inputs.
            benchmark: if True and there is more than one channel, the
                difference is first projected onto a single luma (Y) channel
                with weights (65.738, 129.057, 25.064) / 256. This assumes RGB
                channel order.

        Returns:
            PSNR as a Python float. Returns ``inf`` when the cropped images are
            identical, and ``nan`` when the crop leaves no pixels.
        """
        diff = (sr - hr).detach().div(rgb_range)
        if benchmark:
            shave = scale
            if diff.size(1) > 1:
                convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
                diff = diff.mul(convert).div(256).sum(dim=1, keepdim=True)
        else:
            shave = scale + 6

        # A zero shave would make `shave:-shave` the empty slice `0:0`.
        valid = diff[:, :, shave:-shave, shave:-shave] if shave > 0 else diff
        mse = valid.pow(2).mean().item()
        if mse == 0:
            # math.log10(0) raises ValueError; identical images have infinite PSNR.
            return float('inf')
        return -10 * math.log10(mse)
    
    def make_optimizer(args, my_model):
        """Create the optimizer selected by ``args.optimizer``.

        Only parameters with ``requires_grad`` set are optimized.

        Args:
            args: option namespace.
                - ``optimizer``: 'SGD', 'ADAM' or 'RMSprop'.
                - ``lr`` and ``weight_decay``: always read.
                - ``momentum``: read for SGD.
                - ``beta1`` and ``beta2``: read for ADAM.
                - ``epsilon``: read for ADAM and RMSprop.
            my_model: module whose ``parameters()`` are optimized.

        Returns:
            The configured ``torch.optim`` optimizer.

        Raises:
            ValueError: if ``args.optimizer`` is not one of the names above.
        """
        trainable = [p for p in my_model.parameters() if p.requires_grad]

        if args.optimizer == 'SGD':
            optimizer_function = optim.SGD
            kwargs = {'momentum': args.momentum}
        elif args.optimizer == 'ADAM':
            optimizer_function = optim.Adam
            kwargs = {
                'betas': (args.beta1, args.beta2),
                'eps': args.epsilon
            }
        elif args.optimizer == 'RMSprop':
            optimizer_function = optim.RMSprop
            kwargs = {'eps': args.epsilon}
        else:
            # Without this branch an unknown name surfaced as an
            # UnboundLocalError on optimizer_function.
            raise ValueError(
                "Unknown optimizer '{}'; expected 'SGD', 'ADAM' or 'RMSprop'"
                .format(args.optimizer)
            )

        kwargs['lr'] = args.lr
        kwargs['weight_decay'] = args.weight_decay

        return optimizer_function(trainable, **kwargs)
    
    def make_scheduler(args, my_optimizer):
        """Create the learning-rate scheduler selected by ``args.decay_type``.

        The scheduler depends on ``args.decay_type``:
        - ``'step'``: ``StepLR`` multiplying the LR by ``args.gamma`` every
          ``args.lr_decay`` scheduler steps.
        - any other string containing ``'step'`` (e.g. ``'step_200_400'``):
          ``MultiStepLR`` with ``args.gamma``. Milestones are the
          ``'_'``-separated integers after the first field.

        Raises:
            ValueError: if ``args.decay_type`` does not contain ``'step'``, or
                a milestone field is not an integer.
        """
        if args.decay_type == 'step':
            scheduler = lrs.StepLR(
                my_optimizer,
                step_size=args.lr_decay,
                gamma=args.gamma
            )
        elif 'step' in args.decay_type:
            # Drop the leading tag field and parse the remaining milestones.
            milestones = [int(field) for field in args.decay_type.split('_')[1:]]
            scheduler = lrs.MultiStepLR(
                my_optimizer,
                milestones=milestones,
                gamma=args.gamma
            )
        else:
            # Without this branch the return below raised UnboundLocalError.
            raise ValueError(
                "Unknown decay_type '{}'; expected 'step' or 'step_<m1>_<m2>...'"
                .format(args.decay_type)
            )

        return scheduler