diff --git a/src/loss/at.py b/src/loss/at.py
index 35965a2d7333c06c2fbd28c71968440d34272c84..0f7b5a01a4f8b63f6b125932f4eda91a707d2eb7 100644
--- a/src/loss/at.py
+++ b/src/loss/at.py
@@ -27,7 +27,7 @@ class AT(nn.Module):
 
     def attention_map(self, fm, eps=1e-6):
         am = torch.pow(torch.abs(fm), self.p)
         am = torch.sum(am, dim=1, keepdim=True)
-        #norm = torch.norm(am, keepdim=True)#, dim=(2,3)
-        #am = torch.div(am, norm+eps)
+        norm = torch.norm(am, keepdim=True)#, dim=(2,3)
+        am = torch.div(am, norm+eps)
         return am
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 67721a30de8876086f3e76ca05529e6b6e23faaa..f3a63d74a4e0349485edabe5620cfcd661dc4bd6 100644
--- a/src/main.py
+++ b/src/main.py
@@ -15,7 +15,6 @@
 checkpoint = utility.checkpoint(args)
 check = utility.checkpoint(args)
 teacher_model = model.Model(args, check)
-#teacher_model.load_state_dict(torch.load('/home/iyj0121/AT_EDSR/model_best.pt'), strict=False)
 teacher_model.load(apath='/home/iyj0121/AT_EDSR/')
 teacher_model.eval()
 
diff --git a/src/trainer.py b/src/trainer.py
index 8023045fef46b4dd4fc9bcc8680a6908c750ea99..632eea6d2faecadc4bc8fd875261cff1c0e0fa1c 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -8,8 +8,6 @@ import torch
 import torch.nn.utils as utils
 from tqdm import tqdm
 
-from loss import at
-
 class Trainer():
     def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss):
         self.args = args
@@ -53,7 +51,7 @@ class Trainer():
             with torch.no_grad():
                 t_res, _ = self.t_model(lr, 0)
             kd_loss = self.KD_loss(res, t_res)
-            loss = self.loss(sr, hr) + 0.0001*kd_loss
+            loss = self.loss(sr, hr) + 0.1*kd_loss
             loss.backward()
             if self.args.gclip > 0:
                 utils.clip_grad_value_(
diff --git a/src/utility.py b/src/utility.py
index 0860c49fa6daba70353f437301dd6d25297b7369..8eb6f5e07c9f4f6c292ecb98f3e2231c0f187a18 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -92,9 +92,9 @@ class checkpoint():
     def save(self, trainer, epoch, is_best=False):
         trainer.model.save(self.get_path('model'), 
 epoch, is_best=is_best)
         trainer.loss.save(self.dir)
-        #trainer.loss.plot_loss(self.dir, epoch)
+        trainer.loss.plot_loss(self.dir, epoch)
 
-        #self.plot_psnr(epoch)
+        self.plot_psnr(epoch)
         trainer.optimizer.save(self.dir)
         torch.save(self.log, self.get_path('psnr_log.pt'))