Skip to content
Snippets Groups Projects
videotester.py 2.23 KiB
Newer Older
  • Learn to ignore specific revisions
  • 영제 임's avatar
    영제 임 committed
    import os
    import math
    
    import utility
    from data import common
    
    import torch
    import cv2
    
    from tqdm import tqdm
    
    class VideoTester():
        """Super-resolve a video file frame by frame with a trained model.

        For every scale in ``args.scale`` the video at ``args.dir_demo`` is
        decoded with OpenCV, each frame is passed through the model, and the
        result is written as an XVID ``.avi`` at ``ckp.get_path(...)``.
        """

        def __init__(self, args, my_model, ckp):
            """Store configuration, model and checkpoint/logging helper.

            Args:
                args: options namespace; this class reads ``scale``,
                    ``dir_demo``, ``n_colors``, ``rgb_range``, ``cpu`` and
                    ``precision``.
                my_model: callable as ``my_model(lr, idx_scale)``; must also
                    provide ``eval()``.
                ckp: object providing ``write_log(...)`` and ``get_path(name)``.
            """
            self.args = args
            self.scale = args.scale

            self.ckp = ckp
            self.model = my_model

            # Base name of the input video (no directory, no extension); used
            # to name the output files.
            self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))

        def test(self):
            """Run the model over the demo video once per configured scale.

            Writes ``<filename>_x<scale>.avi`` for each scale and logs the total
            elapsed time through ``ckp.write_log``.

            Raises:
                IOError: if the input video cannot be opened or the output video
                    cannot be created.
            """
            self.ckp.write_log('\nEvaluation on video:')
            self.model.eval()

            timer_test = utility.timer()
            # no_grad() restores the caller's previous grad mode even when a
            # frame raises, unlike a bare set_grad_enabled(False)/(True) pair.
            with torch.no_grad():
                for idx_scale, scale in enumerate(self.scale):
                    self._process_video(idx_scale, scale)

            self.ckp.write_log(
                'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
            )

        def _process_video(self, idx_scale, scale):
            """Super-resolve every frame of the demo video at one scale."""
            vidcap = cv2.VideoCapture(self.args.dir_demo)
            vidwri = None
            try:
                if not vidcap.isOpened():
                    raise IOError(
                        'Cannot open video: {}'.format(self.args.dir_demo)
                    )

                out_path = self.ckp.get_path(
                    '{}_x{}.avi'.format(self.filename, scale)
                )
                vidwri = cv2.VideoWriter(
                    out_path,
                    cv2.VideoWriter_fourcc(*'XVID'),
                    vidcap.get(cv2.CAP_PROP_FPS),
                    (
                        int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                        int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    )
                )
                if not vidwri.isOpened():
                    raise IOError('Cannot create video: {}'.format(out_path))

                # CAP_PROP_FRAME_COUNT is only an estimate for some containers
                # (and may be <= 0), so it only sizes the progress bar; frames
                # are read until the decoder reports no more.
                total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
                with tqdm(
                    total=total_frames if total_frames > 0 else None, ncols=80
                ) as pbar:
                    while True:
                        success, frame = vidcap.read()
                        if not success:
                            break
                        vidwri.write(self._super_resolve(frame, idx_scale))
                        pbar.update(1)
            finally:
                vidcap.release()
                if vidwri is not None:
                    vidwri.release()

        def _super_resolve(self, frame, idx_scale):
            """Run one decoded OpenCV frame through the model; return a frame."""
            # OpenCV decodes frames in BGR order. The pipeline is parameterised
            # in RGB terms (rgb_range); presumably the model was trained on RGB
            # images, so reorder channels here -- TODO confirm against the
            # training data loader.
            lr = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            lr, = common.set_channel(lr, n_channels=self.args.n_colors)
            lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
            lr, = self.prepare(lr.unsqueeze(0))
            sr = self.model(lr, idx_scale)
            sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
            return self._to_frame(sr)

        def _to_frame(self, sr):
            """Convert a channels-first model output to an OpenCV-writable frame.

            Args:
                sr: tensor laid out (C, H, W) with values in
                    ``[0, args.rgb_range]``; may be float16 in half precision.

            Returns:
                An (H, W, C) uint8 numpy array. 3-channel output is reordered
                RGB -> BGR, the channel order cv2.VideoWriter expects.
            """
            # Scale in float32 and round before the cast: in half precision
            # `sr * 255 / rgb_range` can land just below an integer
            # (e.g. 176 -> 175.875) and .byte() would truncate it.
            pixels = sr.float() * 255 / self.args.rgb_range
            pixels = pixels.round().clamp(0, 255)
            if pixels.size(0) == 3:
                pixels = pixels.flip(0)  # RGB -> BGR for OpenCV
            return pixels.byte().permute(1, 2, 0).contiguous().cpu().numpy()

        def prepare(self, *args):
            """Move tensors to the configured device, casting to half if requested.

            Uses the CPU when ``args.cpu`` is set, otherwise CUDA; casts each
            tensor with ``.half()`` when ``args.precision == 'half'``.

            Returns:
                A list with one prepared tensor per positional argument.
            """
            device = torch.device('cpu' if self.args.cpu else 'cuda')
            def _prepare(tensor):
                if self.args.precision == 'half': tensor = tensor.half()
                return tensor.to(device)

            return [_prepare(a) for a in args]