Skip to content
Snippets Groups Projects
Commit b1df88b5 authored by Sanghyun Son's avatar Sanghyun Son
Browse files

Merge branch 'data/video' of https://github.com/thstkdgus35/EDSR-PyTorch into data/video

parents 330499d4 6a6ca3fb
No related branches found
No related tags found
1 merge request!1Jan 09, 2018 updates
...@@ -79,6 +79,7 @@ class Trainer(): ...@@ -79,6 +79,7 @@ class Trainer():
self.model.eval() self.model.eval()
timer_test = utility.timer() timer_test = utility.timer()
if self.args.save_results: self.ckp.begin_background()
with torch.no_grad(): with torch.no_grad():
for idx_scale, scale in enumerate(self.scale): for idx_scale, scale in enumerate(self.scale):
eval_acc = 0 eval_acc = 0
...@@ -87,11 +88,8 @@ class Trainer(): ...@@ -87,11 +88,8 @@ class Trainer():
for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test): for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test):
filename = filename[0] filename = filename[0]
no_eval = (hr.nelement() == 1) no_eval = (hr.nelement() == 1)
if not no_eval:
lr, hr = self.prepare(lr, hr)
else:
lr, = self.prepare(lr)
lr, hr = self.prepare(lr, hr)
sr = self.model(lr, idx_scale) sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range) sr = utility.quantize(sr, self.args.rgb_range)
...@@ -119,10 +117,16 @@ class Trainer(): ...@@ -119,10 +117,16 @@ class Trainer():
) )
self.ckp.write_log( self.ckp.write_log(
'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True 'Forward time: {:.2f}s\n'.format(timer_test.toc())
) )
self.ckp.write_log('Saving...')
if self.args.save_results: self.ckp.end_background()
if not self.args.test_only: if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
self.ckp.write_log(
'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
def prepare(self, *args): def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda') device = torch.device('cpu' if self.args.cpu else 'cuda')
......
...@@ -2,14 +2,15 @@ import os ...@@ -2,14 +2,15 @@ import os
import math import math
import time import time
import datetime import datetime
from functools import reduce from multiprocessing import Process
from multiprocessing import Queue
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import scipy.misc as misc import imageio
import torch import torch
import torch.optim as optim import torch.optim as optim
...@@ -23,8 +24,10 @@ class timer(): ...@@ -23,8 +24,10 @@ class timer():
def tic(self): def tic(self):
self.t0 = time.time() self.t0 = time.time()
def toc(self): def toc(self, restart=False):
return time.time() - self.t0 diff = time.time() - self.t0
if restart: self.t0 = time.time()
return diff
def hold(self): def hold(self):
self.acc += self.toc() self.acc += self.toc()
...@@ -75,6 +78,8 @@ class checkpoint(): ...@@ -75,6 +78,8 @@ class checkpoint():
f.write('{}: {}\n'.format(arg, getattr(args, arg))) f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n') f.write('\n')
self.n_processes = 8
def get_path(self, *subdir): def get_path(self, *subdir):
return os.path.join(self.dir, *subdir) return os.path.join(self.dir, *subdir)
...@@ -121,13 +126,35 @@ class checkpoint(): ...@@ -121,13 +126,35 @@ class checkpoint():
plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test))) plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_test)))
plt.close(fig) plt.close(fig)
def begin_background(self):
self.queue = Queue()
def bg_target(queue):
while True:
if not queue.empty():
filename, tensor = queue.get()
if filename is None: break
imageio.imwrite(filename, tensor.numpy())
self.process = [
Process(target=bg_target, args=(self.queue,)) \
for _ in range(self.n_processes)
]
for p in self.process: p.start()
def end_background(self):
for _ in range(self.n_processes): self.queue.put((None, None))
while not self.queue.empty(): time.sleep(1)
for p in self.process: p.join()
def save_results(self, filename, save_list, scale): def save_results(self, filename, save_list, scale):
filename = self.get_path('results', '{}_x{}_'.format(filename, scale)) filename = self.get_path('results', '{}_x{}_'.format(filename, scale))
postfix = ('SR', 'LR', 'HR') postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix): for v, p in zip(save_list, postfix):
normalized = v[0].mul(255 / self.args.rgb_range) normalized = v[0].mul(255 / self.args.rgb_range)
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
misc.imsave('{}{}.png'.format(filename, p), ndarr) self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
def quantize(img, rgb_range): def quantize(img, rgb_range):
pixel_range = 255 / rgb_range pixel_range = 255 / rgb_range
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment