Commit 5f6b0ff0 authored by Sanghyun SON

Removed unnecessary parts

	Training scripts work well
	EDSR (with scaleRes) included
parent 93794061
Merge request !1: Jan 09, 2018 updates
@@ -14,13 +14,21 @@ If you find our work useful in your research or publication, please cite our work
year = {2017}
}
```
For some reason, we only provide demo script here.
This repository provides demo code for reproducing all the results from the paper. (Training scripts are included.)
However, you can train our model with this code if you make the script and prepare the dataset on your own.
Also, pre-trained model will be uploaded soon.
**Differences from the Torch version**
* The code is much more compact. (All unnecessary parts are removed.)
* Model sizes are smaller. (About half!)
* Training requires less memory. (So we can further increase the model size.)
* Testing is faster.
* Python-based. (Unfortunately, this code does not follow the Python coding conventions. Sorry about that.)
## Dependencies
* Python (Tested with 3.6)
* PyTorch (**Supports 0.20 ONLY**)
* PyTorch (**Supports 0.2.0 ONLY**)
## Code
Clone this repository into any place you want.
@@ -29,7 +37,7 @@ git clone https://github.com/thstkdgus35/EDSR-PyTorch
cd EDSR-PyTorch
```
## Quick Start (Demo)
## Quick start (Demo)
You can test our super-resolution algorithm with your own images.
Place your images in the ```test``` folder (e.g., ```test/puppy.jpeg```).
@@ -38,7 +46,7 @@ Then, run the provided script in ```code``` folder.
Before you run the demo, uncomment the line in ```demo.sh``` that you want to execute.
```bash
cd code
cd code # You are now in */EDSR-PyTorch/code
sh demo.sh
```
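For example, uncommenting the following line in ```demo.sh``` runs the MDSR baseline model on the images in ```test``` (the command is taken from the demo.sh included in this commit; it assumes the pre-trained weights are at ```../experiment/model/MDSR_baseline.pt``` relative to the ```code``` folder):

```bash
# Inside demo.sh: super-resolve your own images with the MDSR baseline model
python main.py --testData myImage --scale 2+3+4 --preTrained ../experiment/model/MDSR_baseline.pt --testOnly True --saveResults True --save test_MDSR_baseline --reset True
```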
@@ -50,6 +58,21 @@ We provide 3 pre-trained models (baseline ONLY, not the full version) so far. You
| Model | Scale | File name | # ResBlocks | # Filters | # Parameters |
| --- | --- | --- | --- | --- | --- |
| **EDSR**| 4 | EDSR_baseline_x4.pt | 16 | 64 | 1.5M |
| **MDSR**| 2 + 3 + 4 | MDSR_baseline.pt | 16 | 64 | 3.2M |
| **MDSR (JPEG)**| 2 + 3 + 4 | MDSR_baseline_jpeg.pt | 16 | 64 | 3.2M |
| **MDSR (JPEG)***| 2 + 3 + 4 | MDSR_baseline_jpeg.pt | 16 | 64 | 3.2M |
*MDSR (JPEG) also reduces JPEG artifacts in the output images. However, its DIV2K validation performance is slightly lower than that of the original MDSR.
## How to train EDSR and MDSR
You have to prepare the DIV2K dataset for training.
**Detailed steps for preparing the data will be added in a future update.**
Training scripts are also included in ``demo.sh``. By uncommenting the appropriate line and executing the script, you can train EDSR and MDSR yourself.
```bash
cd code # You are now in */EDSR-PyTorch/code
sh demo.sh
```
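As an example, these are two of the training lines from the same ``demo.sh`` (shown uncommented here; the other scales and the large models follow the same pattern):

```bash
# Train EDSR_baseline_x2
python main.py --template EDSR --model EDSR --scale 2 --nFeat 64 --nResBlock 16 --patchSize 96 --load EDSR_baseline_x2 --reset True

# Train MDSR_baseline
python main.py --template MDSR --model MDSR --scale 2+3+4 --patchSize 48 --nFeat 64 --nResBlock 16 --load MDSR_baseline --reset True
```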
**EDSR requires a huge amount of memory, even with PyTorch. Therefore, it is currently impossible to test EDSR (not the baseline) on a 12GB GPU. This problem will be fixed in the next update.**
**A detailed guide to the arguments will be added in a future update.**
\ No newline at end of file
@@ -9,22 +9,20 @@ class data:
self.args = args
def getLoader(self):
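# Build the training loader (None when --testOnly) and the test loader(s).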
if not self.args.testOnly:
trainSet = getattr(self.trainModule, self.args.trainData)(self.args)
trainLoader = dataloader.MSDataLoader(
self.args, trainSet, batch_size=self.args.batchSize,
shuffle=True, pin_memory=True)
else:
trainLoader = None
testSet = []
for m in self.testModule:
if m[1] == 'benchmark':
benchmarkList = ['Set5', 'Set14', 'B100', 'Urban100']
for b in benchmarkList:
testSet.append(getattr(m[0], m[1])(self.args, b))
else:
testSet.append(getattr(m[0], m[1])(self.args, train=False))
testSet = getattr(m[0], m[1])(self.args, train=False)
testLoader = [dataloader.MSDataLoader(
self.args, s, batch_size=1,
shuffle=False, pin_memory=True) for s in testSet]
testLoader = dataloader.MSDataLoader(
self.args, testSet, batch_size=1,
shuffle=False, pin_memory=True)
return (trainLoader, testLoader)
from __future__ import print_function
import os
import os.path
import random
import math
import errno
from data import common
import numpy as np
import skimage
import skimage.io as sio
import skimage.color as sc
import torch
import torch.utils.data as data
from torchvision import transforms
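# Test-only dataset for the standard SR benchmark sets (Set5, Set14, B100, Urban100),
# read from <args.dataDir>/benchmark (HR images plus pre-generated LR images per scale).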
class benchmark(data.Dataset):
def __init__(self, args, setName):
self.args = args
self.name = setName
self.train = False
self.scale = args.scale
self.scaleIdx = 0
apath = args.dataDir + '/benchmark'
self.ext = '.png'
dirHR = 'benchmark_test_HR/' + setName
dirLR = 'benchmark_test_LR/' + setName
xScale = ['X{}'.format(s) for s in args.scale]
self.dirIn = [os.path.join(apath, dirLR, xs) for xs in xScale]
self.dirTar = os.path.join(apath, dirHR)
self.fileList = []
for f in os.listdir(self.dirTar):
if f.endswith(self.ext):
fileName, fileExt = os.path.splitext(f)
self.fileList.append(fileName)
def __getitem__(self, idx):
scale = self.scale[self.scaleIdx]
(nameIn, nameTar) = self.getFileName(idx, scale)
imgIn = sio.imread(nameIn)
imgTar = sio.imread(nameTar)
if len(imgIn.shape) == 2:
imgIn = np.expand_dims(imgIn, 2)
imgTar = np.expand_dims(imgTar, 2)
ih, iw, c = imgIn.shape
imgTar = imgTar[0:ih * scale, 0:iw * scale, :]
imgIn, imgTar = common.setChannel(imgIn, imgTar, self.args.nChannel)
return common.np2Tensor(imgIn, imgTar, self.args.rgbRange)
def __len__(self):
return len(self.fileList)
def getFileName(self, idx, scale):
fileName = self.fileList[idx]
nameIn = '{}x{}{}'.format(
fileName, self.scale[self.scaleIdx], self.ext)
nameIn = os.path.join(self.dirIn[self.scaleIdx], nameIn)
nameTar = fileName + self.ext
nameTar = os.path.join(self.dirTar, nameTar)
return nameIn, nameTar
#python main.py --trainData myImage --testData myImage --scale 2+3+4 --preTrained ../experiment/model/MDSR_baseline.pt --testOnly True --saveResults True --save test_MDSR_baseline --reset True
#python main.py --trainData myImage --testData myImage --scale 2+3+4 --preTrained ../experiment/model/MDSR_baseline_jpeg.pt --testOnly True --saveResults True --save test_MDSR_baseline_jpeg --reset True
#python main.py --trainData myImage --testData myImage --scale 4 --preTrained ../experiment/model/EDSR_baseline_x4.pt --testOnly True --saveResults True --save test_EDSR_x4 --reset True
# Demo code for training
# Training EDSR_baseline_x2
#python main.py --template EDSR --model EDSR --scale 2 --nFeat 64 --nResBlock 16 --patchSize 96 --load EDSR_baseline_x2 --reset True
# Training EDSR_baseline_x3
#python main.py --template EDSR --model EDSR --scale 3 --nFeat 64 --nResBlock 16 --patchSize 144 --load EDSR_baseline_x3
# Training EDSR_baseline_x4
#python main.py --template EDSR --model EDSR --scale 4 --nFeat 64 --nResBlock 16 --patchSize 192 --load EDSR_baseline_x4
# Training MDSR_baseline
#python main.py --template MDSR --model MDSR --scale 2+3+4 --patchSize 48 --nFeat 64 --nResBlock 16 --load MDSR_baseline --reset True
# Training EDSR_x2
#python main.py --template EDSR --model EDSR_scale --scale 2 --nFeat 256 --nResBlock 32 --patchSize 96 --load EDSR_x2 --reset True
# Training EDSR_x3
#python main.py --template EDSR --model EDSR_scale --scale 3 --nFeat 256 --nResBlock 32 --patchSize 144 --load EDSR_x3 --reset True
# Training EDSR_x4
#python main.py --template EDSR --model EDSR_scale --scale 4 --nFeat 256 --nResBlock 32 --patchSize 192 --load EDSR_x4 --reset True
# Training MDSR
#python main.py --template MDSR --model MDSR --scale 2+3+4 --patchSize 48 --nFeat 64 --nResBlock 80 --load MDSR --reset True
# Demo code for test (Examples)
# Test with MDSR_baseline
#python main.py --testData myImage --scale 2+3+4 --preTrained ../experiment/model/MDSR_baseline.pt --testOnly True --saveResults True --save test_MDSR_baseline --reset True
# Test with MDSR_baseline_jpeg
#python main.py --testData myImage --scale 2+3+4 --preTrained ../experiment/model/MDSR_baseline_jpeg.pt --testOnly True --saveResults True --save test_MDSR_baseline_jpeg --reset True
# Test with EDSR_x4
#python main.py --testData myImage --scale 4 --preTrained ../experiment/model/EDSR_baseline_x4.pt --testOnly True --saveResults True --save test_EDSR_x4 --reset True
@@ -11,13 +11,10 @@ class EDSR(nn.Module):
self.args = args
subMul, addMul = -1 * args.subMean, 1 * args.subMean
# Submean layer
self.subMean = common.meanShift(
args.rgbRange,
(0.4488, 0.4371, 0.4040),
subMul)
(0.4488, 0.4371, 0.4040), -1 * args.subMean)
# Head convolution for feature extracting
self.headConv = common.conv3x3(args.nChannel, nFeat)
@@ -36,8 +33,7 @@ class EDSR(nn.Module):
# Addmean layer
self.addMean = common.meanShift(
args.rgbRange,
(0.4488, 0.4371, 0.4040),
addMul)
(0.4488, 0.4371, 0.4040), 1 * args.subMean)
def forward(self, x):
x = self.subMean(x)
......
from model import common
from model import EDSR
import torch.nn as nn
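# EDSR variant whose body uses residual-scaled ResBlocks (scale factor 0.1);
# demo.sh selects this model (EDSR_scale) for the large, non-baseline EDSR runs.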
class EDSR_scale(EDSR.EDSR):
def __init__(self, args):
super(EDSR_scale, self).__init__(args)
nResBlock = args.nResBlock
nFeat = args.nFeat
# Main branch
modules = [
common.ResBlock_scale(nFeat, scale=0.1) for _ in range(nResBlock)]
modules.append(common.conv3x3(nFeat, nFeat))
self.body = nn.Sequential(*modules)
@@ -10,13 +10,10 @@ class MDSR(nn.Module):
nFeat = args.nFeat
self.args = args
subMul, addMul = -1 * args.subMean, 1 * args.subMean
# Submean layer
self.subMean = common.meanShift(
args.rgbRange,
(0.4488, 0.4371, 0.4040),
subMul)
(0.4488, 0.4371, 0.4040), -1 * args.subMean)
# Head convolution for feature extracting
self.headConv = common.conv3x3(args.nChannel, nFeat)
@@ -42,8 +39,7 @@ class MDSR(nn.Module):
# Addmean layer
self.addMean = common.meanShift(
args.rgbRange,
(0.4488, 0.4371, 0.4040),
addMul)
(0.4488, 0.4371, 0.4040), 1 * args.subMean)
self.scaleIdx = 0
......
@@ -76,6 +76,19 @@ class ResBlock(nn.Module):
return res
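# ResBlock with residual scaling: the residual branch is multiplied by a small
# constant (0.1 in EDSR) before the skip connection to stabilize training of large models.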
class ResBlock_scale(ResBlock):
def __init__(
self, nFeat, kernel_size=3, bn=False, act=nn.ReLU(True), scale=1):
super(ResBlock_scale, self).__init__(nFeat, kernel_size, bn, act)
self.scale = scale
def forward(self, x):
res = self.body(x)
res *= self.scale
res += x
return res
class upsampler(nn.Module):
def __init__(self, scale, nFeat, act=False):
super(upsampler, self).__init__()
......
@@ -63,6 +63,10 @@ parser.add_argument('--subMean', default=False, metavar='TF',
help='subtract pixel mean from the input')
parser.add_argument('--precision', default='single', metavar='FP',
help='model and data precision')
parser.add_argument('--multiOutput', default=False, metavar='TF',
help='model generates multiple outputs')
parser.add_argument('--multiTarget', default=False, metavar='TF',
help='model requires multiple targets')
# Training specifications
parser.add_argument('--reset', default=False, metavar='TF',
......
def setTemplate(args):
# Set the templates here
if args.template == 'DIV2K':
if args.template == 'EDSR':
args.trainData = 'DIV2K'
args.testData = 'DIV2K'
args.epochs = 300
args.lrDecay = 200
elif args.template == 'MDSR':
args.trainData = 'DIV2K'
args.testData = 'DIV2K'
args.epochs = 650
args.lrDecay = 200
elif args.template == 'DIV2K_jpeg':
args.trainData = 'DIV2K_jpeg'
args.testData = 'DIV2K_jpeg'
args.epochs = 200
args.lrDecay = 100
elif args.template == 'MDSR':
args.trainData = 'DIV2K'
args.testData = 'DIV2K'
args.epochs = 650
args.lrDecay = 200
@@ -22,7 +22,11 @@ class trainer():
def scaleChange(self, scaleIdx, testSet=None):
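# Select the current scale on a multi-scale model (unwrapping DataParallel when
# more than one GPU is used) and, if given, on the test dataset.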
if len(self.scale) > 1:
if self.args.nGPUs == 1:
self.model.setScale(scaleIdx)
else:
self.model.module.setScale(scaleIdx)
if testSet is not None:
testSet.dataset.setScale(scaleIdx)
@@ -67,19 +71,13 @@ class trainer():
testTimer = utils.timer()
self.checkpoint.addLog(
torch.zeros(
1,
len(self.args.task),
len(self.testLoader),
len(self.scale)), False)
for setIdx, testSet in enumerate(self.testLoader):
setName = testSet.dataset.name if testSet.dataset.name else 'Test set'
torch.zeros(1, len(self.scale)), False)
testTimer.tic()
for scaleIdx in range(len(self.scale)):
scale = self.scale[scaleIdx]
self.scaleChange(scaleIdx, testSet)
for imgIdx, (input, target, _) in enumerate(testSet):
self.scaleChange(scaleIdx, self.testLoader)
for imgIdx, (input, target, _) in enumerate(self.testLoader):
input, target = self.prepareData(input, target, volatile=True)
# Self ensemble!
@@ -89,27 +87,24 @@ else:
else:
output = self.model(input)
evalLog = self.evaluate(
self.args, input, output, target, locals())
evalValue = utils.calcPSNR(
output, target,
self.testLoader.dataset.name, self.args.rgbRange, scale)
self.checkpoint.testLog[-1, scaleIdx] \
+= evalValue / len(self.testLoader)
self.checkpoint.saveResults(imgIdx, input, output, target, scale)
if len(self.scale) > 1:
best = self.checkpoint.testLog.squeeze(0).max(0)
else:
best = self.checkpoint.testLog.max(0)
bestValue, bestEpoch = self.checkpoint.testLog.max(0)
performance = 'PSNR: {:.3f}'.format(
self.checkpoint.testLog[-1, scaleIdx])
self.checkpoint.saveLog(
'[SR on {} x{}]\t{} (Best: {:.3f} from epoch {})'.format(
self.testLoader.dataset.name, scale, performance,
bestValue.squeeze(0)[scaleIdx],
bestEpoch.squeeze(0)[scaleIdx] + 1))
for taskIdx, task in enumerate(self.args.task):
performance = '{}: {:.3f}'.format(
evalLog[taskIdx],
self.checkpoint.testLog[-1, taskIdx, setIdx, scaleIdx])
self.checkpoint.saveLog(
'[{} on {} x{}]\t{} (Best: {:.3f} from epoch {})'.format(
task, setName, scale,
performance,
best[0][taskIdx, setIdx, scaleIdx],
best[1][taskIdx, setIdx, scaleIdx] + 1))
self.checkpoint.saveLog('Time: {:.2f}s'.format(testTimer.toc()))
self.checkpoint.saveLog('', refresh=True)
'Time: {:.2f}s\n'.format(testTimer.toc()), refresh=True)
self.checkpoint.save(self, epoch)
def setLr(self):
@@ -137,9 +132,6 @@ class trainer():
input = Variable(input.cuda(), volatile=volatile)
target = Variable(target.cuda())
if self.args.model == 'S2R':
self.model.setMask(mask)
if self.args.precision == 'half':
input = input.half()
target = target.half()
@@ -177,39 +169,6 @@ class trainer():
return lossLog
def evaluate(self, args, input, output, target, etc):
def _doEval(output, target, task):
setName = etc['setName']
if task == 'SR':
return utils.calcPSNR(
output, target, setName,
self.args.rgbRange,
etc['scale']), 'PSNR'
else:
return 0, 'None'
evalLog = [None] * len(args.task)
for taskIdx, task in enumerate(self.args.task):
outputN = output
if etc['setName'] == 'myImage':
targetN = None
else:
targetN = target
evalValue, evalLog[taskIdx] = _doEval(outputN, targetN, task)
setIdx = etc['setIdx']
scaleIdx = etc['scaleIdx']
imgIdx = etc['imgIdx']
self.checkpoint.testLog[-1, taskIdx, setIdx, scaleIdx] \
+= evalValue / len(etc['testSet'])
self.checkpoint.saveResults(
setIdx, imgIdx,
input, outputN, targetN, task, self.scale[scaleIdx])
return evalLog
def terminate(self):
if self.args.testOnly:
self.test()
@@ -220,3 +179,4 @@ class trainer():
return True
else:
return False
@@ -184,11 +184,9 @@ class checkpoint():
plt.savefig('{}/loss_{}.pdf'.format(dir, loss['type']))
plt.close(fig)
for setIdx, testSet in enumerate(trainer.testLoader):
setName = testSet.dataset.name
for taskIdx, task in enumerate(self.args.task):
setName = trainer.testLoader.dataset.name
fig = plt.figure()
label = '{} on {}'.format(task, setName)
label = 'SR on {}'.format(setName)
plt.title(label)
plt.xlabel('Epochs')
plt.grid(True)
@@ -196,24 +194,22 @@ class checkpoint():
legend = 'Scale {}'.format(scale)
plt.plot(
axis,
test[:, taskIdx, setIdx, scaleIdx].numpy(),
test[:, scaleIdx].numpy(),
label=legend)
plt.legend()
plt.savefig(
'{}/test_{}_{}.pdf'.format(dir, task, setName))
'{}/test_SR_{}.pdf'.format(dir, setName))
plt.close(fig)
def getEpoch(self):
return len(self.testLog) + 1
def saveResults(self, setIdx, idx, input, output, target, task, scale):
setIdx += 1
def saveResults(self, idx, input, output, target, scale):
idx += 1
if self.args.saveResults:
if task == 'SR':
fileName = '{}/results/{}-{}x{}_'.format(
self.dir, setIdx, idx, scale)
fileName = '{}/results/{}x{}_'.format(
self.dir, idx, scale)
tUtils.save_image(
input.data[0] / self.args.rgbRange, fileName + 'LR.png')
tUtils.save_image(
......