From 330499d491867306771966868cbc13e51f90e5f0 Mon Sep 17 00:00:00 2001 From: Sanghyun Son <thstkdgus35@snu.ac.kr> Date: Tue, 2 Oct 2018 11:41:06 +0900 Subject: [PATCH] optimize and fix some bugs --- src/videotester.py | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/videotester.py b/src/videotester.py index 2732a0d..0d20fba 100644 --- a/src/videotester.py +++ b/src/videotester.py @@ -28,41 +28,33 @@ class VideoTester(): for idx_scale, scale in enumerate(self.scale): vidcap = cv2.VideoCapture(self.args.dir_demo) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) - total_batches = math.ceil(total_frames / self.args.batch_size) vidwri = cv2.VideoWriter( self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), cv2.VideoWriter_fourcc(*'XVID'), - int(vidcap.get(cv2.CAP_PROP_FPS)), + vidcap.get(cv2.CAP_PROP_FPS), ( int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) ) ) - tqdm_test = tqdm(range(total_batches), ncols=80) + tqdm_test = tqdm(range(total_frames), ncols=80) for _ in tqdm_test: - fs = [] - for _ in range(self.args.batch_size): - success, lr = vidcap.read() - if success: - fs.append(lr) - else: - break - - fs = common.set_channel(*fs, n_channels=self.args.n_colors) - fs = common.np2Tensor(*fs, rgb_range=self.args.rgb_range) - lr = torch.stack(fs, dim=0) - lr, = self.prepare(lr) + success, lr = vidcap.read() + if not success: break + + lr, = common.set_channel(lr, n_channels=self.args.n_colors) + lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) + lr, = self.prepare(lr.unsqueeze(0)) sr = self.model(lr, idx_scale) - sr = utility.quantize(sr, self.args.rgb_range) + sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) - for i in range(self.args.batch_size): - normalized = sr[i].mul(255 / self.args.rgb_range) - ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() - vidwri.write(ndarr) + normalized = sr * 255 / self.args.rgb_range + ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() + vidwri.write(ndarr) - self.vidcap.release() - self.vidwri.release() + vidcap.release() + vidwri.release() self.ckp.write_log( 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True -- GitLab