From 65836384b05949bec2f2cfe0804a7da7dcf99774 Mon Sep 17 00:00:00 2001
From: im_yeong_jae <iyj0121@ajou.ac.kr>
Date: Tue, 2 May 2023 21:40:16 +0900
Subject: [PATCH] start

---
 src/loss/at.py        | 33 +++++++++++++++++++++++++++++++++
 src/main.py           | 11 ++++++++++-
 src/model/__init__.py |  6 +++---
 src/model/edsr.py     |  2 +-
 src/option.py         |  6 +++---
 src/trainer.py        | 20 +++++++++++++-------
 src/utility.py        |  4 ++--
 7 files changed, 65 insertions(+), 17 deletions(-)
 create mode 100644 src/loss/at.py

diff --git a/src/loss/at.py b/src/loss/at.py
new file mode 100644
index 0000000..35965a2
--- /dev/null
+++ b/src/loss/at.py
@@ -0,0 +1,33 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+'''
+AT with sum of absolute values raised to power p
+'''
+class AT(nn.Module):
+    '''
+    Paying More Attention to Attention: Improving the Performance of Convolutional
+    Neural Networks via Attention Transfer
+    https://arxiv.org/pdf/1612.03928.pdf
+    '''
+    def __init__(self, p):
+        super(AT, self).__init__()
+        self.p = p
+
+    def forward(self, fm_s, fm_t):
+        loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))
+
+        return loss
+
+    def attention_map(self, fm, eps=1e-6):
+        am = torch.pow(torch.abs(fm), self.p)
+        am = torch.sum(am, dim=1, keepdim=True)
+        #norm = torch.norm(am, keepdim=True)#, dim=(2,3)
+        #am = torch.div(am, norm+eps)
+
+        return am
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index dbfac3e..d592f12 100644
--- a/src/main.py
+++ b/src/main.py
@@ -7,9 +7,17 @@ import loss
 from option import args
 from trainer import Trainer
 
+from loss import at
+
 torch.manual_seed(args.seed)
 checkpoint = utility.checkpoint(args)
 
+check = utility.checkpoint(args)
+
+teacher_model = model.Model(args, check)
+teacher_model.load_state_dict(torch.load('/home/iyj0121/EDSR-PyTorch/experiment/EDSR_x2.pt'), strict=False)
+teacher_model.eval()
+
 def main():
     global model
     if args.data_test == ['video']:
@@ -22,7 +30,8 @@ def main():
             loader = data.Data(args)
             _model = model.Model(args, checkpoint)
             _loss = loss.Loss(args, checkpoint) if not args.test_only else None
-            t = Trainer(args, loader, _model, _loss, checkpoint)
+            kd_loss = at.AT(p=2.0)
+            t = Trainer(args, loader, _model, _loss, checkpoint, teacher_model, kd_loss)
             while not t.terminate():
                 t.train()
                 t.test()
diff --git a/src/model/__init__.py b/src/model/__init__.py
index 2ffc49d..6d4220b 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -21,9 +21,9 @@ class Model(nn.Module):
         if self.cpu:
             self.device = torch.device('cpu')
         else:
-            if torch.backends.mps.is_available():
-                self.device = torch.device('mps')
-            elif torch.cuda.is_available():
+            #if torch.backends.mps.is_available():
+            #    self.device = torch.device('mps')
+            if torch.cuda.is_available():
                 self.device = torch.device('cuda')
             else:
                 self.device = torch.device('cpu')
diff --git a/src/model/edsr.py b/src/model/edsr.py
index ef4ffb1..74c6e61 100644
--- a/src/model/edsr.py
+++ b/src/model/edsr.py
@@ -62,7 +62,7 @@ class EDSR(nn.Module):
         x = self.tail(res)
         x = self.add_mean(x)
 
-        return x
+        return res, x
 
     def load_state_dict(self, state_dict, strict=True):
         own_state = self.state_dict()
diff --git a/src/option.py b/src/option.py
index 8ec9634..097343a 100644
--- a/src/option.py
+++ b/src/option.py
@@ -19,7 +19,7 @@ parser.add_argument('--seed', type=int, default=1,
                     help='random seed')
 
 # Data specifications
-parser.add_argument('--dir_data', type=str, default='../../../dataset',
+parser.add_argument('--dir_data', type=str, default='/home/iyj0121/EDSR-PyTorch/dataset',
                     help='dataset directory')
 parser.add_argument('--dir_demo', type=str, default='../test',
                     help='demo image directory')
@@ -87,9 +87,9 @@ parser.add_argument('--reset', action='store_true',
                     help='reset the training')
 parser.add_argument('--test_every', type=int, default=1000,
                     help='do test per every N batches')
-parser.add_argument('--epochs', type=int, default=300,
+parser.add_argument('--epochs', type=int, default=100,
                     help='number of epochs to train')
-parser.add_argument('--batch_size', type=int, default=16,
+parser.add_argument('--batch_size', type=int, default=8,
                     help='input batch size for training')
 parser.add_argument('--split_batch', type=int, default=1,
                     help='split the batch into smaller chunks')
diff --git a/src/trainer.py b/src/trainer.py
index 1a6f8cf..07d31a1 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -8,8 +8,10 @@ import torch
 import torch.nn.utils as utils
 from tqdm import tqdm
 
+from loss import at
+
 class Trainer():
-    def __init__(self, args, loader, my_model, my_loss, ckp):
+    def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss):
         self.args = args
         self.scale = args.scale
 
@@ -18,6 +20,8 @@ class Trainer():
         self.loader_test = loader.loader_test
         self.model = my_model
         self.loss = my_loss
+        self.KD_loss = kd_loss
+        self.t_model = teacher_model
         self.optimizer = utility.make_optimizer(args, self.model)
 
         if self.args.load != '':
@@ -45,8 +49,10 @@
             timer_model.tic()
 
             self.optimizer.zero_grad()
-            sr = self.model(lr, 0)
-            loss = self.loss(sr, hr)
+            res, sr = self.model(lr, 0)
+            t_res, _ = self.t_model(lr, 0)
+            kd_loss = self.KD_loss(res, t_res)
+            loss = self.loss(sr, hr) + 0.1*kd_loss
             loss.backward()
             if self.args.gclip > 0:
                 utils.clip_grad_value_(
@@ -88,7 +94,7 @@
                 d.dataset.set_scale(idx_scale)
                 for lr, hr, filename in tqdm(d, ncols=80):
                     lr, hr = self.prepare(lr, hr)
-                    sr = self.model(lr, idx_scale)
+                    _, sr = self.model(lr, idx_scale)
                     sr = utility.quantize(sr, self.args.rgb_range)
 
                     save_list = [sr]
@@ -132,9 +138,9 @@
         if self.args.cpu:
             device = torch.device('cpu')
         else:
-            if torch.backends.mps.is_available():
-                device = torch.device('mps')
-            elif torch.cuda.is_available():
+            #if torch.backends.mps.is_available():
+            #    device = torch.device('mps')
+            if torch.cuda.is_available():
                 device = torch.device('cuda')
             else:
                 device = torch.device('cpu')
diff --git a/src/utility.py b/src/utility.py
index 8eb6f5e..0860c49 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -92,9 +92,9 @@ class checkpoint():
     def save(self, trainer, epoch, is_best=False):
         trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
         trainer.loss.save(self.dir)
-        trainer.loss.plot_loss(self.dir, epoch)
+        #trainer.loss.plot_loss(self.dir, epoch)
 
-        self.plot_psnr(epoch)
+        #self.plot_psnr(epoch)
 
         trainer.optimizer.save(self.dir)
         torch.save(self.log, self.get_path('psnr_log.pt'))
-- 
GitLab
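
For reference, a minimal standalone sketch of the attention-transfer loss added in src/loss/at.py, imported the same way main.py does (run from src/). The batch size and feature-map shapes below are illustrative assumptions, not values taken from the patch:

    import torch

    from loss import at

    # Hypothetical student/teacher body features: spatial sizes must match, but the
    # channel counts may differ, because attention_map() sums |fm|^p over the channel
    # dimension before the two maps are compared with MSE.
    fm_s = torch.randn(8, 64, 48, 48)    # assumed student EDSR body output (res)
    fm_t = torch.randn(8, 256, 48, 48)   # assumed teacher EDSR body output (t_res)

    kd_loss = at.AT(p=2.0)               # same p as used in main.py
    loss = kd_loss(fm_s, fm_t)           # scalar MSE between the two attention maps
    print(loss.item())

In the patch itself these inputs are the res tensors returned by the modified EDSR.forward(), and trainer.py combines the term with the reconstruction loss as loss = self.loss(sr, hr) + 0.1*kd_loss.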