Skip to content
Snippets Groups Projects
Commit 4c11b5d3 authored by Sanhyun Son's avatar Sanhyun Son
Browse files

Changed DIV2K class to new one

    Now DIV2K inherits from the SRData class.
    One can easily add an arbitrary super-resolution dataset
    by specifying its directory rules.
parent 986add60
No related branches found
No related tags found
1 merge request!1Jan 09, 2018 updates
import os import os
import random
import math
from data import common from data import common
from data import SRData
import numpy as np import numpy as np
import scipy.misc as misc import scipy.misc as misc
...@@ -10,125 +9,44 @@ import scipy.misc as misc ...@@ -10,125 +9,44 @@ import scipy.misc as misc
import torch import torch
import torch.utils.data as data import torch.utils.data as data
class DIV2K(data.Dataset): class DIV2K(SRData.SRData):
def __init__(self, args, train=True): def __init__(self, args, train=True):
self._init_basic(args, train) super(DIV2K, self).__init__(args, train)
split = 'train' def _set_filesystem(self, dir_data):
dir_HR = 'DIV2K_{}_HR'.format(split) self.apath = dir_data + '/DIV2K'
dir_LR = 'DIV2K_{}_LR_bicubic'.format(split) self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
x_scale = ['X{}'.format(s) for s in args.scale] self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
self.ext = '.png'
if self.args.ext != 'pack': def _make_filename(self, idx):
self.dir_in = [ return '{:0>4}'.format(idx)
os.path.join(self.apath, dir_LR, xs) for xs in x_scale]
self.dir_tar = os.path.join(self.apath, dir_HR)
else:
print('Preparing binary packages...')
packname = 'pack.pt' if self.train else 'packv.pt'
name_tar = os.path.join(self.apath, dir_HR, packname)
print('\tLoading ' + name_tar)
self.pack_in = []
self.pack_tar = torch.load(name_tar)
if self.train:
self._save_partition(
self.pack_tar,
os.path.join(self.apath, dir_HR, 'packv.pt'))
for i, xs in enumerate(x_scale): def _name_hrfile(self, filename):
name_in = os.path.join(self.apath, dir_LR, xs, packname) return os.path.join(self.dir_hr, filename + self.ext)
print('\tLoading ' + name_in)
self.pack_in.append(torch.load(name_in)) def _name_hrbin(self):
if self.train: return os.path.join(self.apath, '{}_bin_HR.npy'.format(self.split))
self._save_partition(
self.pack_in[i],
os.path.join(self.apath, dir_LR, xs, 'packv.pt'))
def __getitem__(self, idx): def _name_lrfile(self, filename, scale):
scale = self.scale[self.idx_scale] return os.path.join(
idx = self._get_index(idx) self.dir_lr,
img_in, img_tar = self._load_file(idx) 'X{}/{}x{}{}'.format(scale, filename, scale, self.ext))
img_in, img_tar, pi, ai = self._get_patch(img_in, img_tar)
img_in, img_tar = common.set_channel(
img_in, img_tar, self.args.n_colors)
return common.np2Tensor(img_in, img_tar, self.args.rgb_range) def _name_lrbin(self, scale):
return os.path.join(
self.apath,
'{}_bin_LR_X{}.npy'.format(self.split, scale))
def __len__(self): def __len__(self):
if self.train: if self.train:
return self.args.n_train * self.repeat return len(self.images_hr) * self.repeat // self.args.superfetch
else: else:
return self.args.n_val return len(self.images_lr)
def _init_basic(self, args, train):
self.args = args
self.train = train
self.scale = args.scale
self.idx_scale = 0
self.repeat = args.test_every // (args.n_train // args.batch_size)
if args.ext == 'png':
self.apath = args.dir_data + '/DIV2K'
self.ext = '.png'
else:
self.apath = args.dir_data + '/DIV2K_decoded'
self.ext = '.pt'
def _get_index(self, idx): def _get_index(self, idx):
if self.train: if self.train:
idx = (idx % self.args.n_train) + 1 return idx % len(self.images_hr)
else: else:
idx = (idx + self.args.offset_val) + 1
return idx return idx
def _load_file(self, idx):
def _get_filename():
filename = '{:0>4}'.format(idx)
name_in = '{}/{}x{}{}'.format(
self.dir_in[self.idx_scale],
filename,
self.scale[self.idx_scale],
self.ext)
name_tar = os.path.join(self.dir_tar, filename + self.ext)
return name_in, name_tar
if self.args.ext == 'png':
name_in, name_tar = _get_filename()
img_in = misc.imread(name_in)
img_tar = misc.imread(name_tar)
elif self.args.ext == 'pt':
name_in, name_tar = _get_filename()
img_in = torch.load(name_in).numpy()
img_tar = torch.load(name_tar).numpy()
elif self.args.ext == 'pack':
img_in = self.pack_in[self.idx_scale][idx].numpy()
img_tar = self.pack_tar[idx].numpy()
return img_in, img_tar
def _get_patch(self, img_in, img_tar):
scale = self.scale[self.idx_scale]
if self.train:
img_in, img_tar, pi = common.get_patch(
img_in, img_tar, self.args, scale)
img_in, img_tar, ai = common.augment(img_in, img_tar)
return img_in, img_tar, pi, ai
else:
ih, iw, c = img_in.shape
img_tar = img_tar[0:ih * scale, 0:iw * scale, :]
return img_in, img_tar, None, None
def _save_partition(self, dict_full, name):
dict_val = {}
for i in range(self.args.n_train, self.args.n_train + self.args.n_val):
dict_val[i + 1] = dict_full[i + 1]
torch.save(dict_val, name)
def set_scale(self, idx_scale):
self.idx_scale = idx_scale
import os
from data import common
import numpy as np
import scipy.misc as misc
import torch
import torch.utils.data as data
class SRData(data.Dataset):
    """Generic super-resolution dataset base class.

    A concrete dataset (e.g. DIV2K) describes its directory layout by
    implementing the ``_set_filesystem`` / ``_make_filename`` / ``_name_*``
    hooks; this base class handles scanning the HR/LR file lists,
    optional binary (.npy) caching, and patch extraction.
    """

    def __init__(self, args, train=True):
        self.args = args
        self.train = train
        self.split = 'train' if train else 'test'
        self.scale = args.scale
        self.idx_scale = 0
        # Number of times the training set is repeated per "epoch" so that
        # test_every batches are seen between evaluations.
        self.repeat = args.test_every // (args.n_train // args.batch_size)
        self._set_filesystem(args.dir_data)

        def _scan():
            # Build the list of HR filenames and, for every scale, the
            # matching list of LR filenames.
            list_hr = []
            # BUG FIX: the original `[[] * len(self.scale)]` evaluates to
            # `[[]]` (a single empty list), so any second scale raised
            # IndexError below.  Create one list per scale instead.
            list_lr = [[] for _ in self.scale]
            idx_begin = 0 if train else args.n_train
            idx_end = args.n_train if train else args.offset_val + args.n_val
            for i in range(idx_begin + 1, idx_end + 1):
                filename = self._make_filename(i)
                list_hr.append(self._name_hrfile(filename))
                for si, s in enumerate(self.scale):
                    list_lr[si].append(self._name_lrfile(filename, s))

            return list_hr, list_lr

        def _load():
            # Load the pre-decoded image caches (one HR array, one LR
            # array per scale).
            self.images_hr = np.load(self._name_hrbin())
            self.images_lr = [
                np.load(self._name_lrbin(s)) for s in self.scale]

        if args.ext == 'img':
            # Work directly on image files; they are decoded lazily in
            # _load_file.
            self.images_hr, self.images_lr = _scan()
        elif args.ext.find('bin') >= 0:
            try:
                if args.ext.find('reset') >= 0:
                    # 'reset' forces the cache to be rebuilt.
                    raise IOError
                print('Loading a binary file')
                _load()
            except (IOError, OSError, ValueError):
                # Cache missing, unreadable, or deliberately reset:
                # decode every image once and save the .npy caches.
                # (Narrowed from a bare `except:` that also hid real bugs.)
                print('Preparing a binary file')
                list_hr, list_lr = _scan()
                hr = [misc.imread(f) for f in list_hr]
                np.save(self._name_hrbin(), hr)
                del hr
                for si, s in enumerate(self.scale):
                    lr_scale = [misc.imread(f) for f in list_lr[si]]
                    np.save(self._name_lrbin(s), lr_scale)
                    del lr_scale
                _load()
        else:
            print('Please define data type')

    # --- hooks a concrete dataset must implement -----------------------
    def _set_filesystem(self, dir_data):
        raise NotImplementedError

    def _make_filename(self, idx):
        raise NotImplementedError

    def _name_hrfile(self, filename):
        raise NotImplementedError

    def _name_hrbin(self):
        raise NotImplementedError

    def _name_lrfile(self, filename, scale):
        raise NotImplementedError

    def _name_lrbin(self, scale):
        raise NotImplementedError
    # --------------------------------------------------------------------

    def __getitem__(self, idx):
        """Return an (LR, HR) tensor pair for the current scale."""
        img_lr, img_hr = self._load_file(idx)
        img_lr, img_hr = self._get_patch(img_lr, img_hr)
        img_lr, img_hr = common.set_channel(
            img_lr, img_hr, self.args.n_colors)

        return common.np2Tensor(img_lr, img_hr, self.args.rgb_range)

    def __len__(self):
        if self.train:
            return len(self.images_hr)
        else:
            # BUG FIX: the original returned len(self.images_lr), which is
            # the number of *scales* (one container per scale), not the
            # number of test images.
            return len(self.images_hr)

    def _get_index(self, idx):
        # Subclasses may wrap/offset the index (e.g. repeated epochs).
        return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        img_lr = self.images_lr[self.idx_scale][idx]
        img_hr = self.images_hr[idx]
        if self.args.ext == 'img':
            # In 'img' mode the lists hold filenames, not pixel data.
            img_lr = misc.imread(img_lr)
            img_hr = misc.imread(img_hr)

        return img_lr, img_hr

    def _get_patch(self, img_lr, img_hr):
        patch_size = self.args.patch_size
        scale = self.scale[self.idx_scale]
        multi_scale = len(self.scale) > 1
        if self.train:
            img_lr, img_hr = common.get_patch(
                img_lr, img_hr, patch_size, scale, multi_scale=multi_scale)
            img_lr, img_hr = common.augment(img_lr, img_hr)
        else:
            # Evaluation: crop HR so it exactly matches LR size * scale.
            ih, iw, c = img_lr.shape
            img_hr = img_hr[0:ih * scale, 0:iw * scale, :]

        return img_lr, img_hr

    def set_scale(self, idx_scale):
        """Select which scale subsequent __getitem__ calls use."""
        self.idx_scale = idx_scale
from importlib import import_module from importlib import import_module
import dataloader from dataloader import MSDataLoader
from torch.utils.data.dataloader import default_collate
class data: class data:
def __init__(self, args): def __init__(self, args):
...@@ -10,25 +11,33 @@ class data: ...@@ -10,25 +11,33 @@ class data:
self.module_train = import_module('data.' + self.args.data_train) self.module_train = import_module('data.' + self.args.data_train)
self.module_test = import_module('data.' + self.args.data_test) self.module_test = import_module('data.' + self.args.data_test)
kwargs = {}
if self.args.no_cuda:
kwargs['collate_fn'] = default_collate
kwargs['pin_memory'] = False
else:
kwargs['collate_fn'] = default_collate
kwargs['pin_memory'] = True
loader_train = None loader_train = None
if not self.args.test_only: if not self.args.test_only:
trainset = getattr( trainset = getattr(
self.module_train, self.args.data_train)(self.args) self.module_train, self.args.data_train)(self.args)
loader_train = dataloader.MSDataLoader( loader_train = MSDataLoader(
self.args, self.args,
trainset, trainset,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
shuffle=True, shuffle=True,
pin_memory=True) **kwargs)
testset = getattr(self.module_test, self.args.data_test)( testset = getattr(self.module_test, self.args.data_test)(
self.args, train=False) self.args, train=False)
loader_test = dataloader.MSDataLoader( loader_test = MSDataLoader(
self.args, self.args,
testset, testset,
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
pin_memory=True) **kwargs)
return loader_train, loader_test return loader_train, loader_test
...@@ -8,62 +8,57 @@ import skimage.transform as st ...@@ -8,62 +8,57 @@ import skimage.transform as st
import torch import torch
from torchvision import transforms from torchvision import transforms
def get_patch(img_in, img_tar, args, scale, ix=-1, iy=-1): def get_patch(img_in, img_tar, patch_size, scale, multi_scale=False):
(ih, iw, c) = img_in.shape ih, iw, c = img_in.shape
(th, tw) = (scale * ih, scale * iw) th, tw = scale * ih, scale * iw
patch_mult = scale if len(args.scale) > 1 else 1 p = scale if multi_scale else 1
tp = patch_mult * args.patch_size tp = p * patch_size
ip = tp // scale ip = tp // scale
if ix == -1:
ix = random.randrange(0, iw - ip + 1) ix = random.randrange(0, iw - ip + 1)
if iy == -1:
iy = random.randrange(0, ih - ip + 1) iy = random.randrange(0, ih - ip + 1)
tx, ty = scale * ix, scale * iy
(tx, ty) = (scale * ix, scale * iy)
img_in = img_in[iy:iy + ip, ix:ix + ip, :] img_in = img_in[iy:iy + ip, ix:ix + ip, :]
img_tar = img_tar[ty:ty + tp, tx:tx + tp, :] img_tar = img_tar[ty:ty + tp, tx:tx + tp, :]
info_patch = {
'ix': ix, 'iy': iy, 'ip': ip, 'tx': tx, 'ty': ty, 'tp': tp}
return img_in, img_tar, info_patch return img_in, img_tar
def set_channel(img_in, img_tar, n_channel): def set_channel(img_in, img_tar, n_channel):
(h, w, c) = img_tar.shape h, w, c = img_tar.shape
def _set_channel(img):
if n_channel == 1 and c == 3: if n_channel == 1 and c == 3:
img_in = np.expand_dims(sc.rgb2ycbcr(img_in)[:, :, 0], 2) img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
img_tar = np.expand_dims(sc.rgb2ycbcr(img_tar)[:, :, 0], 2)
elif n_channel == 3 and c == 1: elif n_channel == 3 and c == 1:
img_in = np.concatenate([img_in] * n_channel, 2) img = np.concatenate([img] * n_channel, 2)
img_tar = np.concatenate([img_tar] * n_channel, 2)
return img_in, img_tar return img
return _set_channel(img_in), _set_channel(img_tar)
def np2Tensor(img_in, img_tar, rgb_range): def np2Tensor(img_in, img_tar, rgb_range):
ts = (2, 0, 1) def _to_tensor(img):
img_mul = rgb_range / 255 np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
img_in = torch.Tensor(img_in.transpose(ts).astype(float)).mul_(img_mul) torch_tensor = torch.from_numpy(np_transpose).float()
img_tar = torch.Tensor(img_tar.transpose(ts).astype(float)).mul_(img_mul) torch_tensor.mul_(rgb_range / 255)
return img_in, img_tar return torch_tensor
return _to_tensor(img_in), _to_tensor(img_tar)
def augment(img_in, img_tar, hflip=True, rot=True):
    """Randomly flip/transpose an (LR, HR) image pair with identical ops.

    Each of horizontal flip, vertical flip, and 90-degree rotation
    (transpose) is applied with probability 0.5 when enabled, and the
    exact same transformation is applied to both images.
    """
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_rot90 = rot and random.random() < 0.5

    def _apply(img):
        if do_hflip:
            img = img[:, ::-1, :]
        if do_vflip:
            img = img[::-1, :, :]
        if do_rot90:
            img = img.transpose(1, 0, 2)

        return img

    return _apply(img_in), _apply(img_tar)
def augment(img_in, img_tar, flip_h=True, rot=True):
    """Randomly flip/transpose an (LR, HR) image pair with identical ops.

    Returns the (possibly transformed) pair plus an info dict recording
    which of the three augmentations was applied.
    """
    info_aug = {'flip_h': False, 'flip_v': False, 'trans': False}

    # NOTE: random.random() is drawn before checking flip_h, matching the
    # original draw order so RNG consumption is identical.
    if random.random() < 0.5 and flip_h:
        img_in, img_tar = img_in[:, ::-1, :], img_tar[:, ::-1, :]
        info_aug['flip_h'] = True

    if rot:
        if random.random() < 0.5:
            img_in, img_tar = img_in[::-1, :, :], img_tar[::-1, :, :]
            info_aug['flip_v'] = True
        if random.random() < 0.5:
            img_in = img_in.transpose(1, 0, 2)
            img_tar = img_tar.transpose(1, 0, 2)
            info_aug['trans'] = True

    return img_in, img_tar, info_aug
...@@ -2,6 +2,7 @@ import sys ...@@ -2,6 +2,7 @@ import sys
import threading import threading
import queue import queue
import random import random
import collections
import torch import torch
import torch.multiprocessing as multiprocessing import torch.multiprocessing as multiprocessing
...@@ -116,13 +117,13 @@ class MSDataLoader(DataLoader): ...@@ -116,13 +117,13 @@ class MSDataLoader(DataLoader):
def __init__( def __init__(
self, args, dataset, batch_size=1, shuffle=False, self, args, dataset, batch_size=1, shuffle=False,
sampler=None, batch_sampler=None, sampler=None, batch_sampler=None,
pin_memory=False, drop_last=False, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None): timeout=0, worker_init_fn=None):
super(MSDataLoader, self).__init__( super(MSDataLoader, self).__init__(
dataset, batch_size=batch_size, shuffle=shuffle, dataset, batch_size=batch_size, shuffle=shuffle,
sampler=sampler, batch_sampler=batch_sampler, sampler=sampler, batch_sampler=batch_sampler,
num_workers=args.n_threads, collate_fn=default_collate, num_workers=args.n_threads, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=drop_last, pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn) timeout=timeout, worker_init_fn=worker_init_fn)
......
# EDSR baseline model (x2) # EDSR baseline model (x2)
#python main.py --model EDSR --scale 2 --save EDSR_baseline_x2 --reset python main.py --model EDSR --scale 2 --save EDSR_baseline_x2 --reset --print_every 1
# EDSR baseline model (x3) - requires pre-trained EDSR baseline x2 model # EDSR baseline model (x3) - requires pre-trained EDSR baseline x2 model
#python main.py --model EDSR --scale 3 --save EDSR_baseline_x3 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt #python main.py --model EDSR --scale 3 --save EDSR_baseline_x3 --reset --pre_train ../experiment/model/EDSR_baseline_x2.pt
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR --n_resblocks 80 --reset #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR --n_resblocks 80 --reset
# Test your own images # Test your own images
python main.py --scale 4 --data_test MyImage --test_only --save_results --pre_train ../experiment/model/EDSR_baseline_x4.pt --chop_forward #python main.py --scale 4 --data_test MyImage --test_only --save_results --pre_train ../experiment/model/EDSR_baseline_x4.pt --chop_forward
#Advanced - JPEG artifact removal #Advanced - JPEG artifact removal
#python main.py --template MDSR_jpeg --model MDSR --scale 2+3+4 --save MDSR_jpeg --quality 75+ --reset #python main.py --template MDSR_jpeg --model MDSR --scale 2+3+4 --save MDSR_jpeg --quality 75+ --reset
...@@ -31,7 +31,7 @@ parser.add_argument('--n_val', type=int, default=10, ...@@ -31,7 +31,7 @@ parser.add_argument('--n_val', type=int, default=10,
help='number of validation set') help='number of validation set')
parser.add_argument('--offset_val', type=int, default=800, parser.add_argument('--offset_val', type=int, default=800,
help='validation index offest') help='validation index offest')
parser.add_argument('--ext', type=str, default='pack', parser.add_argument('--ext', type=str, default='bin',
help='dataset file extension') help='dataset file extension')
parser.add_argument('--scale', default='4', parser.add_argument('--scale', default='4',
help='super resolution scale') help='super resolution scale')
...@@ -45,6 +45,8 @@ parser.add_argument('--quality', type=str, default='', ...@@ -45,6 +45,8 @@ parser.add_argument('--quality', type=str, default='',
help='jpeg compression quality') help='jpeg compression quality')
parser.add_argument('--chop_forward', action='store_true', parser.add_argument('--chop_forward', action='store_true',
help='enable memory-efficient forward') help='enable memory-efficient forward')
parser.add_argument('--superfetch', type=int, default=1,
help='fetch multiple batches at onece')
# Model specifications # Model specifications
parser.add_argument('--model', default='EDSR', parser.add_argument('--model', default='EDSR',
...@@ -133,6 +135,10 @@ for i, q in enumerate(args.quality): ...@@ -133,6 +135,10 @@ for i, q in enumerate(args.quality):
if args.epochs == 0: if args.epochs == 0:
args.epochs = 1e8 args.epochs = 1e8
if args.superfetch > 1:
args.batch_size *= args.superfetch
args.print_every //= args.superfetch
for arg in vars(args): for arg in vars(args):
if vars(args)[arg] == 'True': if vars(args)[arg] == 'True':
vars(args)[arg] = True vars(args)[arg] = True
......
...@@ -47,14 +47,16 @@ class Trainer(): ...@@ -47,14 +47,16 @@ class Trainer():
timer_data, timer_model = utils.timer(), utils.timer() timer_data, timer_model = utils.timer(), utils.timer()
for batch, (input, target, idx_scale) in enumerate(self.loader_train): for batch, (input, target, idx_scale) in enumerate(self.loader_train):
input, target = self._prepare(input, target) input, target = self._prepare(input, target)
chunks_input = input.chunk(self.args.superfetch, dim=0)
chunks_target = target.chunk(self.args.superfetch, dim=0)
self._scale_change(idx_scale) self._scale_change(idx_scale)
timer_data.hold() timer_data.hold()
timer_model.tic() timer_model.tic()
for ci, ct in zip(chunks_input, chunks_target):
self.optimizer.zero_grad() self.optimizer.zero_grad()
output = self.model(input) output = self.model(ci)
loss = self._calc_loss(output, target) loss = self._calc_loss(output, ct)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
...@@ -152,8 +154,9 @@ class Trainer(): ...@@ -152,8 +154,9 @@ class Trainer():
return loss_total return loss_total
def _display_loss(self, batch): def _display_loss(self, batch):
n_samples = self.args.superfetch * (batch + 1)
log = [ log = [
'[{}: {:.4f}] '.format(t['type'], l / (batch + 1)) \ '[{}: {:.4f}] '.format(t['type'], l / n_samples) \
for l, t in zip(self.ckp.log_training[-1], self.loss)] for l, t in zip(self.ckp.log_training[-1], self.loss)]
return ''.join(log) return ''.join(log)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment