From e7d9d6b3ac4b3c826de1fe8a89ac950ece13eae0 Mon Sep 17 00:00:00 2001 From: im_yeong_jae <iyj0121@ajou.ac.kr> Date: Wed, 3 May 2023 22:51:34 +0900 Subject: [PATCH] teacher model change --- src/loss/at.py | 4 ++-- src/main.py | 1 - src/trainer.py | 4 +--- src/utility.py | 4 ++-- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/loss/at.py b/src/loss/at.py index 35965a2..0f7b5a0 100644 --- a/src/loss/at.py +++ b/src/loss/at.py @@ -27,7 +27,7 @@ class AT(nn.Module): def attention_map(self, fm, eps=1e-6): am = torch.pow(torch.abs(fm), self.p) am = torch.sum(am, dim=1, keepdim=True) - #norm = torch.norm(am, keepdim=True)#, dim=(2,3) - #am = torch.div(am, norm+eps) + norm = torch.norm(am, keepdim=True)#, dim=(2,3) + am = torch.div(am, norm+eps) return am \ No newline at end of file diff --git a/src/main.py b/src/main.py index 67721a3..f3a63d7 100644 --- a/src/main.py +++ b/src/main.py @@ -15,7 +15,6 @@ checkpoint = utility.checkpoint(args) check = utility.checkpoint(args) teacher_model = model.Model(args, check) -#teacher_model.load_state_dict(torch.load('/home/iyj0121/AT_EDSR/model_best.pt'), strict=False) teacher_model.load(apath='/home/iyj0121/AT_EDSR/') teacher_model.eval() diff --git a/src/trainer.py b/src/trainer.py index 8023045..632eea6 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -8,8 +8,6 @@ import torch import torch.nn.utils as utils from tqdm import tqdm -from loss import at - class Trainer(): def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss): self.args = args @@ -53,7 +51,7 @@ class Trainer(): with torch.no_grad(): t_res, _ = self.t_model(lr, 0) kd_loss = self.KD_loss(res, t_res) - loss = self.loss(sr, hr) + 0.0001*kd_loss + loss = self.loss(sr, hr) + 0.1*kd_loss loss.backward() if self.args.gclip > 0: utils.clip_grad_value_( diff --git a/src/utility.py b/src/utility.py index 0860c49..8eb6f5e 100644 --- a/src/utility.py +++ b/src/utility.py @@ -92,9 +92,9 @@ class checkpoint(): def save(self, trainer, epoch, is_best=False): trainer.model.save(self.get_path('model'), epoch, is_best=is_best) trainer.loss.save(self.dir) - #trainer.loss.plot_loss(self.dir, epoch) + trainer.loss.plot_loss(self.dir, epoch) - #self.plot_psnr(epoch) + self.plot_psnr(epoch) trainer.optimizer.save(self.dir) torch.save(self.log, self.get_path('psnr_log.pt')) -- GitLab