diff --git a/src/trainer.py b/src/trainer.py index c610294c40fadd9eaefb139c0a33596d450ec4e3..baff03c4c12a58b468f56cc4ff69a5cf0b9ca01a 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -79,6 +79,7 @@ class Trainer(): self.model.eval() timer_test = utility.timer() + if self.args.save_results: self.ckp.begin_background() with torch.no_grad(): for idx_scale, scale in enumerate(self.scale): eval_acc = 0 @@ -87,11 +88,8 @@ class Trainer(): for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test): filename = filename[0] no_eval = (hr.nelement() == 1) - if not no_eval: - lr, hr = self.prepare(lr, hr) - else: - lr, = self.prepare(lr) + lr, hr = self.prepare(lr, hr) sr = self.model(lr, idx_scale) sr = utility.quantize(sr, self.args.rgb_range) @@ -119,10 +117,16 @@ class Trainer(): ) self.ckp.write_log( - 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True + 'Forward time: {:.2f}s\n'.format(timer_test.toc()) ) + + self.ckp.write_log('Saving...') + if self.args.save_results: self.ckp.end_background() if not self.args.test_only: self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) + self.ckp.write_log( + 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True + ) def prepare(self, *args): device = torch.device('cpu' if self.args.cpu else 'cuda') diff --git a/src/utility.py b/src/utility.py index b5033c65ecafb4a9ffdebc2bacd72d7869be3a00..1793e26434ba4326b4136c699835fadd25876e3e 100644 --- a/src/utility.py +++ b/src/utility.py @@ -2,14 +2,15 @@ import os import math import time import datetime -from functools import reduce +from multiprocessing import Process +from multiprocessing import Queue import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np -import scipy.misc as misc +import imageio import torch import torch.optim as optim @@ -23,8 +24,10 @@ class timer(): def tic(self): self.t0 = time.time() - def toc(self): - return time.time() - self.t0 + def toc(self, restart=False): + diff = time.time() - self.t0 + if restart: self.t0 = time.time() + return diff def hold(self): self.acc += self.toc() @@ -75,6 +78,8 @@ class checkpoint(): f.write('{}: {}\n'.format(arg, getattr(args, arg))) f.write('\n') + self.n_processes = 8 + def get_path(self, *subdir): return os.path.join(self.dir, *subdir) @@ -121,13 +126,35 @@ class checkpoint(): plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test))) plt.close(fig) + def begin_background(self): + self.queue = Queue() + + def bg_target(queue): + while True: + if not queue.empty(): + filename, tensor = queue.get() + if filename is None: break + imageio.imwrite(filename, tensor.numpy()) + + self.process = [ + Process(target=bg_target, args=(self.queue,)) \ + for _ in range(self.n_processes) + ] + + for p in self.process: p.start() + + def end_background(self): + for _ in range(self.n_processes): self.queue.put((None, None)) + while not self.queue.empty(): time.sleep(1) + for p in self.process: p.join() + def save_results(self, filename, save_list, scale): filename = self.get_path('results', '{}_x{}_'.format(filename, scale)) postfix = ('SR', 'LR', 'HR') for v, p in zip(save_list, postfix): normalized = v[0].mul(255 / self.args.rgb_range) - ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() - misc.imsave('{}{}.png'.format(filename, p), ndarr) + tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() + self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) def quantize(img, rgb_range): pixel_range = 255 / rgb_range