Commit 4a0e0cae authored by 영제 임: first commit
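# ---- sspcab.py (filename assumed from content): SSPCAB block, adapted from the
# official "Self-Supervised Predictive Convolutional Attentive Block" release ----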
# This code is released under the CC BY-SA 4.0 license.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Squeeze and Excitation block
class SELayer(nn.Module):
def __init__(self, num_channels, reduction_ratio=8):
'''
num_channels: The number of input channels
reduction_ratio: The reduction ratio 'r' from the paper
'''
super(SELayer, self).__init__()
num_channels_reduced = num_channels // reduction_ratio
self.reduction_ratio = reduction_ratio
self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, input_tensor):
batch_size, num_channels, H, W = input_tensor.size()
        squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2)  # squeeze: global average pooling over H x W
# channel excitation
fc_out_1 = self.relu(self.fc1(squeeze_tensor))
fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
        output_tensor = input_tensor * fc_out_2.view(batch_size, num_channels, 1, 1)
return output_tensor
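# Usage sketch (ours, not part of the original file): SE only rescales each
# channel by a weight in (0, 1), so the output shape matches the input.
# se = SELayer(num_channels=16, reduction_ratio=8)
# x = torch.randn(2, 16, 32, 32)
# assert se(x).shape == x.shape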
# SSPCAB implementation
class SSPCAB(nn.Module):
def __init__(self, channels, kernel_dim=1, dilation=1, reduction_ratio=8):
'''
        channels: The number of filters at the output (usually the same as the number of filters at the input)
        kernel_dim: The dimension of the sub-kernels k' from the paper
dilation: The dilation dimension 'd' from the paper
reduction_ratio: The reduction ratio for the SE block ('r' from the paper)
'''
super(SSPCAB, self).__init__()
self.pad = kernel_dim + dilation
self.border_input = kernel_dim + 2*dilation + 1
self.relu = nn.ReLU()
self.se = SELayer(channels, reduction_ratio=reduction_ratio)
self.conv1 = nn.Conv2d(in_channels=channels,
out_channels=channels,
kernel_size=kernel_dim)
self.conv2 = nn.Conv2d(in_channels=channels,
out_channels=channels,
kernel_size=kernel_dim)
self.conv3 = nn.Conv2d(in_channels=channels,
out_channels=channels,
kernel_size=kernel_dim)
self.conv4 = nn.Conv2d(in_channels=channels,
out_channels=channels,
kernel_size=kernel_dim)
def forward(self, x):
        # zero-pad, then apply the four displaced sub-kernels of the masked convolution
        x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), "constant", 0)
x1 = self.conv1(x[:, :, :-self.border_input, :-self.border_input])
x2 = self.conv2(x[:, :, self.border_input:, :-self.border_input])
x3 = self.conv3(x[:, :, :-self.border_input, self.border_input:])
x4 = self.conv4(x[:, :, self.border_input:, self.border_input:])
x = self.relu(x1 + x2 + x3 + x4)
x = self.se(x)
return x
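# Shape note and sketch (ours): with padding k+d on every side and offset
# k+2d+1, each sub-convolution sees an (H+k-1) x (W+k-1) crop, so a kernel of
# size k maps it back to H x W and SSPCAB preserves spatial dimensions.
# block = SSPCAB(channels=64, kernel_dim=1, dilation=1)
# x = torch.randn(2, 64, 32, 32)
# assert block(x).shape == x.shape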
class MY_SSPCAB(nn.Module):
def __init__(self, channels, kernel_dim=1, dilation=1, reduction_ratio=8):
'''
        channels: The number of filters at the output (usually the same as the number of filters at the input)
        kernel_dim: The dimension of the sub-kernels k' from the paper
dilation: The dilation dimension 'd' from the paper
reduction_ratio: The reduction ratio for the SE block ('r' from the paper)
'''
super(MY_SSPCAB, self).__init__()
self.relu = nn.ReLU()
        self.sspcab = SSPCAB(channels=channels, kernel_dim=kernel_dim, dilation=dilation, reduction_ratio=reduction_ratio)
        self.input_conv = nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_dim)
        self.output_conv = nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=kernel_dim)
def forward(self, x):
        x1 = self.input_conv(x)
x2 = self.sspcab(x1)
x3 = self.output_conv(x2)
return x3
# Example of how the block's self-supervised reconstruction cost should be computed
# mse_loss = nn.MSELoss()
# cost_sspcab = mse_loss(input_sspcab, output_sspcab)
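# Fuller sketch (ours): the wrapper above is trained to reconstruct its own
# input, and trainer.py below adds this cost to the main SR loss with weight 0.1.
# model = MY_SSPCAB(channels=64)
# x = torch.randn(2, 3, 32, 32)
# cost_sspcab = nn.MSELoss()(model(x), x)

# ---- vdsr.py (filename assumed; matches model/vdsr.py in EDSR-PyTorch) ----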
from model import common
import torch.nn as nn
import torch.nn.init as init
url = {
'r20f64': ''
}
def make_model(args, parent=False):
return VDSR(args)
class VDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(VDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
self.sub_mean = common.MeanShift(args.rgb_range)
self.add_mean = common.MeanShift(args.rgb_range, sign=1)
def basic_block(in_channels, out_channels, act):
return common.BasicBlock(
conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=act
)
# define body module
m_body = []
m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
for _ in range(n_resblocks - 2):
m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
m_body.append(basic_block(n_feats, args.n_colors, None))
self.body = nn.Sequential(*m_body)
def forward(self, x):
x = self.sub_mean(x)
res = self.body(x)
res += x
x = self.add_mean(res)
return x
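# Usage sketch (ours; relies on model.common from EDSR-PyTorch, and on the
# 'r20f64' entry in the url table, hence n_resblocks=20 and n_feats=64):
# from types import SimpleNamespace
# args = SimpleNamespace(n_resblocks=20, n_feats=64, rgb_range=255, n_colors=3)
# vdsr = make_model(args)

# ---- option.py (filename assumed; matches option.py in EDSR-PyTorch) ----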
import argparse
import template
parser = argparse.ArgumentParser(description='EDSR and MDSR')
parser.add_argument('--debug', action='store_true',
help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in template.py')
# Hardware specifications
parser.add_argument('--n_threads', type=int, default=6,
help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
help='random seed')
# Data specifications
parser.add_argument('--dir_data', type=str, default='/home/iyj0121/EDSR-PyTorch/dataset',
help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-810',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=192,
help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
help='number of color channels to use')
parser.add_argument('--chop', action='store_true',
help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='EDSR',
help='model name')
parser.add_argument('--act', type=str, default='relu',
help='activation function')
parser.add_argument('--pre_train', type=str, default='',
help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=16,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
help='residual scaling')
parser.add_argument('--shift_mean', default=True,
help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
choices=('single', 'half'),
help='FP precision for test (single | half)')
# Option for Residual dense network (RDN)
parser.add_argument('--G0', type=int, default=64,
help='default number of filters. (Use in RDN)')
parser.add_argument('--RDNkSize', type=int, default=3,
help='default kernel size. (Use in RDN)')
parser.add_argument('--RDNconfig', type=str, default='B',
help='parameters config of RDN. (Use in RDN)')
# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=10,
help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
help='number of feature maps reduction')
# Training specifications
parser.add_argument('--reset', action='store_true',
help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='run a test every N training batches')
parser.add_argument('--epochs', type=int, default=300,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
help='k value for adversarial loss')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--decay', type=str, default='200',
help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
choices=('SGD', 'ADAM', 'RMSprop'),
help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
help='SGD momentum')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                    help='ADAM betas (note: type=tuple cannot parse a CLI string; change the default in code instead)')
parser.add_argument('--epsilon', type=float, default=1e-8,
help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)')
# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default=1e8,
                    help='skip batches whose error exceeds this threshold')
# Log specifications
parser.add_argument('--save', type=str, default='test',
help='file name to save')
parser.add_argument('--load', type=str, default='',
help='file name to load')
parser.add_argument('--resume', type=int, default=0,
help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
help='save output results')
parser.add_argument('--save_gt', action='store_true',
help='save low-resolution and high-resolution images together')
args = parser.parse_args()
template.set_template(args)
args.scale = list(map(lambda x: int(x), args.scale.split('+')))
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
    args.epochs = int(1e8)
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False
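# Example (ours) of the post-processing above: '--scale 2+3+4' yields
# args.scale == [2, 3, 4], and '--data_train DIV2K+Flickr2K' yields
# args.data_train == ['DIV2K', 'Flickr2K'].

# ---- template.py (filename assumed; matches template.py in EDSR-PyTorch) ----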
def set_template(args):
# Set the templates here
if args.template.find('jpeg') >= 0:
args.data_train = 'DIV2K_jpeg'
args.data_test = 'DIV2K_jpeg'
args.epochs = 200
args.decay = '100'
if args.template.find('EDSR_paper') >= 0:
args.model = 'EDSR'
args.n_resblocks = 32
args.n_feats = 256
args.res_scale = 0.1
if args.template.find('MDSR') >= 0:
args.model = 'MDSR'
args.patch_size = 48
args.epochs = 650
if args.template.find('DDBPN') >= 0:
args.model = 'DDBPN'
args.patch_size = 128
args.scale = '4'
args.data_test = 'Set5'
args.batch_size = 20
args.epochs = 1000
args.decay = '500'
args.gamma = 0.1
args.weight_decay = 1e-4
args.loss = '1*MSE'
if args.template.find('GAN') >= 0:
args.epochs = 200
args.lr = 5e-5
args.decay = '150'
if args.template.find('RCAN') >= 0:
args.model = 'RCAN'
args.n_resgroups = 10
args.n_resblocks = 20
args.n_feats = 64
args.chop = True
if args.template.find('VDSR') >= 0:
args.model = 'VDSR'
args.n_resblocks = 20
args.n_feats = 64
args.patch_size = 41
args.lr = 1e-1
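# Note (ours): matching uses substring find(), so templates can be combined,
# e.g. '--template GAN+VDSR' applies both blocks; on conflicting fields the
# later block in this function wins (here VDSR's lr = 1e-1 overrides GAN's).

# ---- trainer.py (training and evaluation loop) ----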
import os
import math
from decimal import Decimal
import utility
import torch
import torch.nn.utils as utils
from tqdm import tqdm
class Trainer():
def __init__(self, args, loader, my_model, my_loss, ckp, sub_model):
self.args = args
self.scale = args.scale
self.ckp = ckp
self.loader_train = loader.loader_train
self.loader_test = loader.loader_test
self.model = my_model
self.sub_sspcab = sub_model
self.loss = my_loss
self.optimizer = utility.make_optimizer(args, self.model)
if self.args.load != '':
self.optimizer.load(ckp.dir, epoch=len(ckp.log))
self.error_last = 1e8
def train(self):
self.loss.step()
epoch = self.optimizer.get_last_epoch() + 1
lr = self.optimizer.get_lr()
self.ckp.write_log(
'[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
)
self.loss.start_log()
self.model.train()
timer_data, timer_model = utility.timer(), utility.timer()
# TEMP
self.loader_train.dataset.set_scale(0)
        # note: from here on, lr is the low-resolution batch, shadowing the learning rate above
        for batch, (lr, hr, _,) in enumerate(self.loader_train):
lr, hr = self.prepare(lr, hr)
timer_data.hold()
timer_model.tic()
self.optimizer.zero_grad()
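            # Auxiliary SSPCAB objective (this commit's addition): sub_sspcab
            # reconstructs the LR batch, and its reconstruction cost is added
            # to the SR loss below with weight 0.1. It reuses the main loss
            # module, so both terms enter the same loss log.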
sub_sr = self.sub_sspcab(lr)
sub_loss = self.loss(sub_sr, lr)
sr = self.model(lr, 0)
loss = self.loss(sr, hr) + 0.1*sub_loss
loss.backward()
if self.args.gclip > 0:
utils.clip_grad_value_(
self.model.parameters(),
self.args.gclip
)
self.optimizer.step()
timer_model.hold()
if (batch + 1) % self.args.print_every == 0:
self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
(batch + 1) * self.args.batch_size,
len(self.loader_train.dataset),
self.loss.display_loss(batch),
timer_model.release(),
timer_data.release()))
timer_data.tic()
self.loss.end_log(len(self.loader_train))
self.error_last = self.loss.log[-1, -1]
self.optimizer.schedule()
def test(self):
torch.set_grad_enabled(False)
epoch = self.optimizer.get_last_epoch()
self.ckp.write_log('\nEvaluation:')
self.ckp.add_log(
torch.zeros(1, len(self.loader_test), len(self.scale))
)
self.model.eval()
timer_test = utility.timer()
if self.args.save_results: self.ckp.begin_background()
for idx_data, d in enumerate(self.loader_test):
for idx_scale, scale in enumerate(self.scale):
d.dataset.set_scale(idx_scale)
for lr, hr, filename in tqdm(d, ncols=80):
lr, hr = self.prepare(lr, hr)
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range)
save_list = [sr]
self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
sr, hr, scale, self.args.rgb_range, dataset=d
)
if self.args.save_gt:
save_list.extend([lr, hr])
if self.args.save_results:
self.ckp.save_results(d, filename[0], save_list, scale)
self.ckp.log[-1, idx_data, idx_scale] /= len(d)
best = self.ckp.log.max(0)
self.ckp.write_log(
'[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
d.dataset.name,
scale,
self.ckp.log[-1, idx_data, idx_scale],
best[0][idx_data, idx_scale],
best[1][idx_data, idx_scale] + 1
)
)
self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
self.ckp.write_log('Saving...')
if self.args.save_results:
self.ckp.end_background()
if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
if self.args.cpu:
device = torch.device('cpu')
else:
#if torch.backends.mps.is_available():
# device = torch.device('mps')
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]
def terminate(self):
if self.args.test_only:
self.test()
return True
else:
epoch = self.optimizer.get_last_epoch() + 1
return epoch >= self.args.epochs
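# ---- utility.py (timers, checkpointing, metrics, optimizer factory) ----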
import os
import math
import time
import datetime
from multiprocessing import Process
from multiprocessing import Queue
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import imageio
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
class timer():
def __init__(self):
self.acc = 0
self.tic()
def tic(self):
self.t0 = time.time()
def toc(self, restart=False):
diff = time.time() - self.t0
if restart: self.t0 = time.time()
return diff
def hold(self):
self.acc += self.toc()
def release(self):
ret = self.acc
self.acc = 0
return ret
def reset(self):
self.acc = 0
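# Usage sketch (ours; do_work is a placeholder): tic/toc time one span, while
# hold/release accumulate across spans, as the trainer does to separate
# data-loading time from model time.
# t = timer()
# t.tic(); do_work(); t.hold()
# elapsed = t.release()  # read and reset the accumulator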
def bg_target(queue):
    while True:
        filename, tensor = queue.get()  # blocking get avoids busy-waiting on queue.empty()
        if filename is None: break
        imageio.imwrite(filename, tensor.numpy())
class checkpoint():
def __init__(self, args):
self.args = args
self.ok = True
self.log = torch.Tensor()
now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if not args.load:
if not args.save:
args.save = now
self.dir = os.path.join('..', 'experiment', args.save)
else:
self.dir = os.path.join('..', 'experiment', args.load)
if os.path.exists(self.dir):
self.log = torch.load(self.get_path('psnr_log.pt'))
print('Continue from epoch {}...'.format(len(self.log)))
else:
args.load = ''
if args.reset:
os.system('rm -rf ' + self.dir)
args.load = ''
os.makedirs(self.dir, exist_ok=True)
os.makedirs(self.get_path('model'), exist_ok=True)
for d in args.data_test:
os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
self.log_file = open(self.get_path('log.txt'), open_type)
with open(self.get_path('config.txt'), open_type) as f:
f.write(now + '\n\n')
for arg in vars(args):
f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n')
self.n_processes = 8
def get_path(self, *subdir):
return os.path.join(self.dir, *subdir)
def save(self, trainer, epoch, is_best=False):
trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
trainer.loss.save(self.dir)
trainer.loss.plot_loss(self.dir, epoch)
self.plot_psnr(epoch)
trainer.optimizer.save(self.dir)
torch.save(self.log, self.get_path('psnr_log.pt'))
def add_log(self, log):
self.log = torch.cat([self.log, log])
def write_log(self, log, refresh=False):
print(log)
self.log_file.write(log + '\n')
if refresh:
self.log_file.close()
self.log_file = open(self.get_path('log.txt'), 'a')
def done(self):
self.log_file.close()
def plot_psnr(self, epoch):
axis = np.linspace(1, epoch, epoch)
for idx_data, d in enumerate(self.args.data_test):
label = 'SR on {}'.format(d)
fig = plt.figure()
plt.title(label)
for idx_scale, scale in enumerate(self.args.scale):
plt.plot(
axis,
self.log[:, idx_data, idx_scale].numpy(),
label='Scale {}'.format(scale)
)
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('PSNR')
plt.grid(True)
plt.savefig(self.get_path('test_{}.pdf'.format(d)))
plt.close(fig)
def begin_background(self):
self.queue = Queue()
self.process = [
Process(target=bg_target, args=(self.queue,)) \
for _ in range(self.n_processes)
]
for p in self.process: p.start()
def end_background(self):
for _ in range(self.n_processes): self.queue.put((None, None))
while not self.queue.empty(): time.sleep(1)
for p in self.process: p.join()
def save_results(self, dataset, filename, save_list, scale):
if self.args.save_results:
filename = self.get_path(
'results-{}'.format(dataset.dataset.name),
'{}_x{}_'.format(filename, scale)
)
postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix):
normalized = v[0].mul(255 / self.args.rgb_range)
tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
def quantize(img, rgb_range):
pixel_range = 255 / rgb_range
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
if hr.nelement() == 1: return 0
diff = (sr - hr) / rgb_range
if dataset and dataset.dataset.benchmark:
shave = scale
if diff.size(1) > 1:
gray_coeffs = [65.738, 129.057, 25.064]
convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
diff = diff.mul(convert).sum(dim=1)
else:
shave = scale + 6
valid = diff[..., shave:-shave, shave:-shave]
mse = valid.pow(2).mean()
return -10 * math.log10(mse)
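# Worked example (ours): diff is normalized by rgb_range, so
# PSNR = -10 * log10(mse); an MSE of 1e-3 on the shaved region gives 30 dB.
# Identical images give mse = 0, which this function does not guard against.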
def make_optimizer(args, target):
'''
make optimizer and scheduler together
'''
# optimizer
trainable = filter(lambda x: x.requires_grad, target.parameters())
kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
if args.optimizer == 'SGD':
optimizer_class = optim.SGD
kwargs_optimizer['momentum'] = args.momentum
elif args.optimizer == 'ADAM':
optimizer_class = optim.Adam
kwargs_optimizer['betas'] = args.betas
kwargs_optimizer['eps'] = args.epsilon
elif args.optimizer == 'RMSprop':
optimizer_class = optim.RMSprop
kwargs_optimizer['eps'] = args.epsilon
# scheduler
milestones = list(map(lambda x: int(x), args.decay.split('-')))
kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
scheduler_class = lrs.MultiStepLR
class CustomOptimizer(optimizer_class):
def __init__(self, *args, **kwargs):
super(CustomOptimizer, self).__init__(*args, **kwargs)
def _register_scheduler(self, scheduler_class, **kwargs):
self.scheduler = scheduler_class(self, **kwargs)
def save(self, save_dir):
torch.save(self.state_dict(), self.get_dir(save_dir))
def load(self, load_dir, epoch=1):
self.load_state_dict(torch.load(self.get_dir(load_dir)))
if epoch > 1:
for _ in range(epoch): self.scheduler.step()
def get_dir(self, dir_path):
return os.path.join(dir_path, 'optimizer.pt')
def schedule(self):
self.scheduler.step()
        def get_lr(self):
            # get_last_lr() reports the lr currently in effect; plain get_lr()
            # is unreliable outside scheduler.step() in recent PyTorch
            return self.scheduler.get_last_lr()[0]
def get_last_epoch(self):
return self.scheduler.last_epoch
optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
return optimizer
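# Usage sketch (ours): the returned optimizer carries its scheduler, so a
# training loop steps it once per epoch via schedule().
# optimizer = make_optimizer(args, model)
# for epoch in range(args.epochs):
#     ...  # zero_grad(), backward(), optimizer.step() per batch
#     optimizer.schedule()

# ---- videotester.py (video super-resolution demo) ----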
import os
import math
import utility
from data import common
import torch
import cv2
from tqdm import tqdm
class VideoTester():
def __init__(self, args, my_model, ckp):
self.args = args
self.scale = args.scale
self.ckp = ckp
self.model = my_model
self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
def test(self):
torch.set_grad_enabled(False)
self.ckp.write_log('\nEvaluation on video:')
self.model.eval()
timer_test = utility.timer()
for idx_scale, scale in enumerate(self.scale):
vidcap = cv2.VideoCapture(self.args.dir_demo)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
vidwri = cv2.VideoWriter(
self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
cv2.VideoWriter_fourcc(*'XVID'),
vidcap.get(cv2.CAP_PROP_FPS),
(
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
)
)
tqdm_test = tqdm(range(total_frames), ncols=80)
for _ in tqdm_test:
success, lr = vidcap.read()
if not success: break
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
lr, = self.prepare(lr.unsqueeze(0))
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
normalized = sr * 255 / self.args.rgb_range
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
vidwri.write(ndarr)
vidcap.release()
vidwri.release()
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]