Skip to content
Snippets Groups Projects
Commit 6a6ca3fb authored by Sanghyun Son's avatar Sanghyun Son
Browse files

multiprocessing for save_results

parent ab6b0ad7
No related tags found
1 merge request!1Jan 09, 2018 updates
......@@ -79,6 +79,7 @@ class Trainer():
self.model.eval()
timer_test = utility.timer()
if self.args.save_results: self.ckp.begin_background()
with torch.no_grad():
for idx_scale, scale in enumerate(self.scale):
eval_acc = 0
......@@ -87,11 +88,8 @@ class Trainer():
for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test):
filename = filename[0]
no_eval = (hr.nelement() == 1)
if not no_eval:
lr, hr = self.prepare(lr, hr)
else:
lr, = self.prepare(lr)
lr, hr = self.prepare(lr, hr)
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range)
......@@ -119,10 +117,16 @@ class Trainer():
)
self.ckp.write_log(
'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
'Forward time: {:.2f}s\n'.format(timer_test.toc())
)
self.ckp.write_log('Saving...')
if self.args.save_results: self.ckp.end_background()
if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
self.ckp.write_log(
'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
......
......@@ -2,14 +2,15 @@ import os
import math
import time
import datetime
from functools import reduce
from multiprocessing import Process
from multiprocessing import Queue
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc as misc
import imageio
import torch
import torch.optim as optim
......@@ -23,8 +24,10 @@ class timer():
def tic(self):
self.t0 = time.time()
def toc(self):
return time.time() - self.t0
def toc(self, restart=False):
diff = time.time() - self.t0
if restart: self.t0 = time.time()
return diff
def hold(self):
self.acc += self.toc()
......@@ -75,6 +78,8 @@ class checkpoint():
f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n')
self.n_processes = 8
def get_path(self, *subdir):
return os.path.join(self.dir, *subdir)
......@@ -121,13 +126,35 @@ class checkpoint():
plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test)))
plt.close(fig)
def begin_background(self):
self.queue = Queue()
def bg_target(queue):
while True:
if not queue.empty():
filename, tensor = queue.get()
if filename is None: break
imageio.imwrite(filename, tensor.numpy())
self.process = [
Process(target=bg_target, args=(self.queue,)) \
for _ in range(self.n_processes)
]
for p in self.process: p.start()
def end_background(self):
for _ in range(self.n_processes): self.queue.put((None, None))
while not self.queue.empty(): time.sleep(1)
for p in self.process: p.join()
def save_results(self, filename, save_list, scale):
filename = self.get_path('results', '{}_x{}_'.format(filename, scale))
postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix):
normalized = v[0].mul(255 / self.args.rgb_range)
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
misc.imsave('{}{}.png'.format(filename, p), ndarr)
tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
def quantize(img, rgb_range):
pixel_range = 255 / rgb_range
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment