Skip to content
Snippets Groups Projects
utility.py 7.26 KiB
Newer Older
  • Learn to ignore specific revisions
  • 영제 임's avatar
    영제 임 committed
    import os
    import math
    import time
    import datetime
    from multiprocessing import Process
    from multiprocessing import Queue
    
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    
    import numpy as np
    import imageio
    
    import torch
    import torch.optim as optim
    import torch.optim.lr_scheduler as lrs
    
    class timer():
        '''Simple stopwatch with an accumulator for summing several intervals.

        ``tic``/``toc`` measure a single interval; ``hold`` adds the current
        interval to ``acc`` and ``release`` returns the accumulated total and
        clears it.
        '''
        def __init__(self):
            self.acc = 0
            self.tic()

        def tic(self):
            '''Mark the start of a new interval.'''
            # perf_counter is monotonic, so measured intervals can never be
            # negative even if the system wall clock is adjusted meanwhile.
            self.t0 = time.perf_counter()

        def toc(self, restart=False):
            '''Return the seconds elapsed since the last start mark.

            Args:
                restart: when true, the start mark is moved to the instant
                    that was just measured.

            Returns:
                Elapsed time in seconds as a float.
            '''
            # Read the clock once so that, on restart, the new interval begins
            # exactly where the returned one ended (no time is dropped).
            now = time.perf_counter()
            diff = now - self.t0
            if restart:
                self.t0 = now
            return diff

        def hold(self):
            '''Add the time elapsed since the start mark to the accumulator.'''
            self.acc += self.toc()

        def release(self):
            '''Return the accumulated time and reset the accumulator to zero.'''
            ret = self.acc
            self.acc = 0

            return ret

        def reset(self):
            '''Discard the accumulated time.'''
            self.acc = 0
    
    def bg_target(queue):
        '''Worker loop: write queued images to disk until a sentinel arrives.

        Items taken from ``queue`` are ``(filename, tensor)`` pairs. A pair
        whose ``filename`` is ``None`` stops the loop; any other pair is
        written with ``imageio.imwrite(filename, tensor.numpy())``.

        Args:
            queue: queue object providing a blocking ``get()``.
        '''
        while True:
            # Block in get() instead of polling empty() in a tight loop, which
            # kept every idle worker spinning on the CPU.
            filename, tensor = queue.get()
            if filename is None:
                break
            imageio.imwrite(filename, tensor.numpy())
    
    class checkpoint():
        '''Experiment directory manager: logs, PSNR history, plots and results.

        On construction it resolves the experiment directory from
        ``args.save`` / ``args.load``, optionally wipes it (``args.reset``),
        creates the ``model`` and ``results-<name>`` sub-directories, opens
        ``log.txt`` and appends the argument values to ``config.txt``.
        '''
        def __init__(self, args):
            '''Prepare the experiment directory described by ``args``.

            Reads ``args.load``, ``args.save``, ``args.reset`` and
            ``args.data_test``; may overwrite ``args.save`` (with a timestamp
            when empty) and ``args.load`` (cleared when the directory is
            missing or reset).
            '''
            # Local import keeps this block self-contained; shutil is stdlib.
            import shutil

            self.args = args
            self.ok = True
            self.log = torch.Tensor()
            now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

            if not args.load:
                if not args.save:
                    args.save = now
                self.dir = os.path.join('..', 'experiment', args.save)
            else:
                self.dir = os.path.join('..', 'experiment', args.load)
                if os.path.exists(self.dir):
                    self.log = torch.load(self.get_path('psnr_log.pt'))
                    print('Continue from epoch {}...'.format(len(self.log)))
                else:
                    args.load = ''

            if args.reset:
                # rmtree instead of a shell 'rm -rf <dir>' string: no shell is
                # involved, so paths with spaces or shell metacharacters are
                # handled correctly. ignore_errors mirrors rm -f's tolerance
                # of a missing directory.
                shutil.rmtree(self.dir, ignore_errors=True)
                args.load = ''
                # The directory is gone, so a PSNR history loaded from it
                # above no longer describes this experiment.
                self.log = torch.Tensor()

            os.makedirs(self.dir, exist_ok=True)
            os.makedirs(self.get_path('model'), exist_ok=True)
            for d in args.data_test:
                os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)

            # Append when a previous log exists, otherwise start fresh files.
            open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
            self.log_file = open(self.get_path('log.txt'), open_type)
            with open(self.get_path('config.txt'), open_type) as f:
                f.write(now + '\n\n')
                for arg in vars(args):
                    f.write('{}: {}\n'.format(arg, getattr(args, arg)))
                f.write('\n')

            # Number of background writer processes started by
            # begin_background().
            self.n_processes = 8

        def get_path(self, *subdir):
            '''Return ``subdir`` components joined under the experiment dir.'''
            return os.path.join(self.dir, *subdir)

        def save(self, trainer, epoch, is_best=False):
            '''Save model, loss, optimizer, PSNR plots and the PSNR log.

            Args:
                trainer: object exposing ``model``, ``loss`` and ``optimizer``
                    with the ``save`` / ``plot_loss`` methods called below.
                epoch: current epoch number, forwarded to the savers/plots.
                is_best: forwarded to ``trainer.model.save``.
            '''
            trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
            trainer.loss.save(self.dir)
            trainer.loss.plot_loss(self.dir, epoch)

            self.plot_psnr(epoch)
            trainer.optimizer.save(self.dir)
            torch.save(self.log, self.get_path('psnr_log.pt'))

        def add_log(self, log):
            '''Concatenate ``log`` onto the stored PSNR log tensor.'''
            self.log = torch.cat([self.log, log])

        def write_log(self, log, refresh=False):
            '''Print ``log`` and append it as a line to ``log.txt``.

            Args:
                log: text to record.
                refresh: when true, close and reopen the log file in append
                    mode so the text written so far reaches the file.
            '''
            print(log)
            self.log_file.write(log + '\n')
            if refresh:
                self.log_file.close()
                self.log_file = open(self.get_path('log.txt'), 'a')

        def done(self):
            '''Close the log file.'''
            self.log_file.close()

        def plot_psnr(self, epoch):
            '''Write one ``test_<dataset>.pdf`` PSNR plot per test dataset.

            The x axis is ``1..epoch``; one curve is drawn per entry of
            ``args.scale`` from ``self.log[:, idx_data, idx_scale]``.
            '''
            axis = np.linspace(1, epoch, epoch)
            for idx_data, d in enumerate(self.args.data_test):
                label = 'SR on {}'.format(d)
                fig = plt.figure()
                plt.title(label)
                for idx_scale, scale in enumerate(self.args.scale):
                    plt.plot(
                        axis,
                        self.log[:, idx_data, idx_scale].numpy(),
                        label='Scale {}'.format(scale)
                    )
                plt.legend()
                plt.xlabel('Epochs')
                plt.ylabel('PSNR')
                plt.grid(True)
                plt.savefig(self.get_path('test_{}.pdf'.format(d)))
                # Close the figure so figures do not pile up across epochs.
                plt.close(fig)

        def begin_background(self):
            '''Create the work queue and start the writer processes.'''
            self.queue = Queue()

            self.process = [
                Process(target=bg_target, args=(self.queue,))
                for _ in range(self.n_processes)
            ]

            for p in self.process:
                p.start()

        def end_background(self):
            '''Send one stop sentinel per worker, wait for the queue to drain,
            then join every worker process.'''
            for _ in range(self.n_processes):
                self.queue.put((None, None))
            while not self.queue.empty():
                time.sleep(1)
            for p in self.process:
                p.join()

        def save_results(self, dataset, filename, save_list, scale):
            '''Queue result images for writing when ``args.save_results`` is set.

            Args:
                dataset: object whose ``dataset.name`` selects the
                    ``results-<name>`` directory.
                filename: base name of the output files.
                save_list: tensors paired in order with the postfixes
                    ``SR``, ``LR``, ``HR``; only the first item of each
                    is saved.
                scale: value embedded in the file name as ``_x<scale>_``.
            '''
            if self.args.save_results:
                filename = self.get_path(
                    'results-{}'.format(dataset.dataset.name),
                    '{}_x{}_'.format(filename, scale)
                )

                postfix = ('SR', 'LR', 'HR')
                for v, p in zip(save_list, postfix):
                    # Rescale from [0, rgb_range] to [0, 255] before the
                    # byte cast.
                    normalized = v[0].mul(255 / self.args.rgb_range)
                    # presumably (C, H, W) -> (H, W, C) for the image
                    # writer; confirm against callers.
                    tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
                    self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
    
    def quantize(img, rgb_range):
        '''Snap ``img`` to the 256-level grid and return it in its own range.

        The tensor is scaled by ``255 / rgb_range``, clamped to ``[0, 255]``,
        rounded, and scaled back by the same factor.
        '''
        factor = 255 / rgb_range
        scaled = img.mul(factor)
        snapped = scaled.clamp(0, 255).round()
        return snapped.div(factor)
    
    def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
        '''Compute the PSNR (in dB) between ``sr`` and ``hr``.

        Args:
            sr: tensor to evaluate.
            hr: reference tensor; a single-element ``hr`` short-circuits to 0.
            scale: border width basis; ``scale`` pixels are shaved from each
                spatial edge for benchmark datasets, ``scale + 6`` otherwise.
            rgb_range: value range used to normalise the difference.
            dataset: optional object; when ``dataset.dataset.benchmark`` is
                truthy and the difference has more than one channel, the
                difference is reduced to a single weighted channel first.

        Returns:
            ``0`` when ``hr`` has a single element, ``inf`` when the compared
            regions are identical, otherwise ``-10 * log10(mse)`` as a float.
        '''
        if hr.nelement() == 1: return 0

        diff = (sr - hr) / rgb_range
        if dataset and dataset.dataset.benchmark:
            shave = scale
            if diff.size(1) > 1:
                # Weighted channel sum; the view(1, 3, 1, 1) requires exactly
                # three channels on dim 1.
                gray_coeffs = [65.738, 129.057, 25.064]
                convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
                diff = diff.mul(convert).sum(dim=1)
        else:
            shave = scale + 6

        # Drop a border of `shave` pixels on the last two dims. NOTE: a shave
        # of 0 would make `shave:-shave` an empty slice.
        valid = diff[..., shave:-shave, shave:-shave]
        mse = float(valid.pow(2).mean())

        # log10(0) raises ValueError; a zero error means infinite PSNR.
        if mse == 0:
            return float('inf')

        return -10 * math.log10(mse)
    
    def make_optimizer(args, target):
        '''
            make optimizer and scheduler together

            Builds an optimizer over the parameters of ``target`` that have
            ``requires_grad`` set, and attaches a ``MultiStepLR`` scheduler
            to it.

            Args:
                args: namespace providing ``lr``, ``weight_decay``,
                    ``optimizer`` ('SGD', 'ADAM' or 'RMSprop'), the matching
                    hyper-parameters (``momentum`` / ``betas`` / ``epsilon``),
                    ``decay`` (milestone epochs joined by '-') and ``gamma``.
                target: object exposing ``parameters()``.

            Returns:
                An optimizer instance extended with ``save``, ``load``,
                ``schedule``, ``get_lr`` and ``get_last_epoch`` helpers.

            Raises:
                ValueError: if ``args.optimizer`` is not a supported name.
        '''
        # optimizer
        trainable = filter(lambda x: x.requires_grad, target.parameters())
        kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}

        if args.optimizer == 'SGD':
            optimizer_class = optim.SGD
            kwargs_optimizer['momentum'] = args.momentum
        elif args.optimizer == 'ADAM':
            optimizer_class = optim.Adam
            kwargs_optimizer['betas'] = args.betas
            kwargs_optimizer['eps'] = args.epsilon
        elif args.optimizer == 'RMSprop':
            optimizer_class = optim.RMSprop
            kwargs_optimizer['eps'] = args.epsilon
        else:
            # Without this branch an unknown name surfaced later as an
            # UnboundLocalError on optimizer_class.
            raise ValueError(
                'Unsupported optimizer: {!r} '
                '(expected SGD, ADAM or RMSprop)'.format(args.optimizer)
            )

        # scheduler
        milestones = [int(x) for x in args.decay.split('-')]
        kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
        scheduler_class = lrs.MultiStepLR

        class CustomOptimizer(optimizer_class):
            '''Selected optimizer plus scheduler and checkpoint helpers.'''

            def _register_scheduler(self, scheduler_class, **kwargs):
                self.scheduler = scheduler_class(self, **kwargs)

            def save(self, save_dir):
                '''Write the optimizer state dict to <save_dir>/optimizer.pt.'''
                torch.save(self.state_dict(), self.get_dir(save_dir))

            def load(self, load_dir, epoch=1):
                '''Restore the optimizer state and, when epoch > 1, step the
                scheduler ``epoch`` times.'''
                self.load_state_dict(torch.load(self.get_dir(load_dir)))
                if epoch > 1:
                    for _ in range(epoch): self.scheduler.step()

            def get_dir(self, dir_path):
                return os.path.join(dir_path, 'optimizer.pt')

            def schedule(self):
                '''Advance the scheduler by one step.'''
                self.scheduler.step()

            def get_lr(self):
                '''Return the current learning rate of the first param group.'''
                # get_last_lr() reports the rate the scheduler last applied.
                # Calling get_lr() outside step() on recent PyTorch applies the
                # milestone decay a second time; it is kept only as a fallback
                # for versions that lack get_last_lr().
                if hasattr(self.scheduler, 'get_last_lr'):
                    return self.scheduler.get_last_lr()[0]
                return self.scheduler.get_lr()[0]

            def get_last_epoch(self):
                return self.scheduler.last_epoch

        optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
        optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
        return optimizer