From e7d9d6b3ac4b3c826de1fe8a89ac950ece13eae0 Mon Sep 17 00:00:00 2001
From: im_yeong_jae <iyj0121@ajou.ac.kr>
Date: Wed, 3 May 2023 22:51:34 +0900
Subject: [PATCH] Change teacher model loading, enable AT normalization, and raise KD loss weight

---
 src/loss/at.py | 4 ++--
 src/main.py    | 1 -
 src/trainer.py | 4 +---
 src/utility.py | 4 ++--
 4 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/src/loss/at.py b/src/loss/at.py
index 35965a2..0f7b5a0 100644
--- a/src/loss/at.py
+++ b/src/loss/at.py
@@ -27,7 +27,7 @@ class AT(nn.Module):
 	def attention_map(self, fm, eps=1e-6):
 		am = torch.pow(torch.abs(fm), self.p)
 		am = torch.sum(am, dim=1, keepdim=True)
-		#norm = torch.norm(am, keepdim=True)#, dim=(2,3)
-		#am = torch.div(am, norm+eps)
+		norm = torch.norm(am, keepdim=True)#, dim=(2,3)
+		am = torch.div(am, norm+eps)
 
 		return am
\ No newline at end of file
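
Note on the hunk above: it re-enables L2 normalization of the attention map used for attention-transfer (AT) distillation, so the map becomes scale-invariant before student and teacher maps are compared. A minimal self-contained sketch of the resulting computation; p=2 and the tensor shapes are assumptions for illustration, not values taken from the repo:

    import torch

    def attention_map(fm, p=2, eps=1e-6):
        # Channel-pooled |fm|**p, then globally L2-normalized
        # (mirrors the patched AT.attention_map; p=2 stands in for self.p).
        am = torch.abs(fm).pow(p).sum(dim=1, keepdim=True)  # (N, C, H, W) -> (N, 1, H, W)
        norm = torch.norm(am)                                # L2 norm over all elements
        return am / (norm + eps)

    # Illustrative comparison of student and teacher maps (shapes are made up):
    s_am = attention_map(torch.randn(2, 64, 48, 48))
    t_am = attention_map(torch.randn(2, 256, 48, 48))        # channel pooling aligns the shapes
    kd = (s_am - t_am).pow(2).mean()                          # one common AT-style distance
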
diff --git a/src/main.py b/src/main.py
index 67721a3..f3a63d7 100644
--- a/src/main.py
+++ b/src/main.py
@@ -15,7 +15,6 @@ checkpoint = utility.checkpoint(args)
 check = utility.checkpoint(args)
 
 teacher_model = model.Model(args, check)
-#teacher_model.load_state_dict(torch.load('/home/iyj0121/AT_EDSR/model_best.pt'), strict=False)
 teacher_model.load(apath='/home/iyj0121/AT_EDSR/')
 teacher_model.eval()
 
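
Note on the hunk above: instead of loading a raw state dict from a hard-coded checkpoint file, the teacher is now restored through the model wrapper's own load(apath=...) helper pointed at the experiment directory. A hypothetical sketch of what such a helper typically does in EDSR-style code; the method body, checkpoint filename, and keyword arguments below are assumptions, not this repo's actual implementation:

    import os
    import torch

    def load(self, apath, filename='model_best.pt', cpu=False):
        # Hypothetical Model.load: read a checkpoint found under apath and
        # load it non-strictly, as the old commented-out call did.
        kwargs = {'map_location': lambda storage, loc: storage} if cpu else {}
        state = torch.load(os.path.join(apath, filename), **kwargs)
        self.load_state_dict(state, strict=False)
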
diff --git a/src/trainer.py b/src/trainer.py
index 8023045..632eea6 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -8,8 +8,6 @@ import torch
 import torch.nn.utils as utils
 from tqdm import tqdm
 
-from loss import at
-
 class Trainer():
     def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss):
         self.args = args
@@ -53,7 +51,7 @@ class Trainer():
             with torch.no_grad():
                 t_res, _ = self.t_model(lr, 0)
             kd_loss = self.KD_loss(res, t_res)
-            loss = self.loss(sr, hr) + 0.0001*kd_loss
+            loss = self.loss(sr, hr) + 0.1*kd_loss
             loss.backward()
             if self.args.gclip > 0:
                 utils.clip_grad_value_(
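
Note on the two hunks above: the unused "from loss import at" import is dropped (the distillation loss is passed in as kd_loss), and the weight on the distillation term is raised from 0.0001 to 0.1. A runnable toy sketch of the weighted objective; the student, teacher, and L1 losses below are stand-ins, and only the 0.1 weighting mirrors the patch:

    import torch
    import torch.nn as nn

    # Toy stand-ins for the super-resolution student/teacher pair.
    student = nn.Conv2d(3, 3, 3, padding=1)
    teacher = nn.Conv2d(3, 3, 3, padding=1).eval()
    l1 = nn.L1Loss()

    lr_img = torch.randn(2, 3, 24, 24)   # low-resolution input (toy size)
    hr_img = torch.randn(2, 3, 24, 24)   # ground-truth target (toy size)

    sr = student(lr_img)
    with torch.no_grad():
        t_out = teacher(lr_img)          # teacher forward pass, no gradients

    kd_loss = l1(sr, t_out)                 # placeholder distillation term (the repo uses the AT loss)
    loss = l1(sr, hr_img) + 0.1 * kd_loss   # reconstruction loss + newly weighted KD term
    loss.backward()
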
diff --git a/src/utility.py b/src/utility.py
index 0860c49..8eb6f5e 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -92,9 +92,9 @@ class checkpoint():
     def save(self, trainer, epoch, is_best=False):
         trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
         trainer.loss.save(self.dir)
-        #trainer.loss.plot_loss(self.dir, epoch)
+        trainer.loss.plot_loss(self.dir, epoch)
 
-        #self.plot_psnr(epoch)
+        self.plot_psnr(epoch)
         trainer.optimizer.save(self.dir)
         torch.save(self.log, self.get_path('psnr_log.pt'))
 
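
Note on the hunk above: loss and PSNR plotting are re-enabled at every checkpoint save, so each epoch writes updated curves next to the saved model and psnr_log.pt. A rough sketch of the kind of plot this produces; the function name, output file name, and log layout are assumptions:

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')                 # headless backend for training servers
    import matplotlib.pyplot as plt

    def plot_psnr_sketch(psnr_log, path='test_PSNR.pdf'):
        # psnr_log: one validation PSNR value per epoch (assumed 1-D sequence).
        epochs = np.arange(1, len(psnr_log) + 1)
        plt.figure()
        plt.plot(epochs, psnr_log)
        plt.xlabel('Epoch')
        plt.ylabel('PSNR (dB)')
        plt.grid(True)
        plt.savefig(path)
        plt.close()
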
-- 
GitLab