Skip to content
Snippets Groups Projects
Commit e7d9d6b3 authored by 영제 임's avatar 영제 임
Browse files

teacher model change

parent 58b8660a
Branches
No related tags found
No related merge requests found
...@@ -27,7 +27,7 @@ class AT(nn.Module): ...@@ -27,7 +27,7 @@ class AT(nn.Module):
def attention_map(self, fm, eps=1e-6): def attention_map(self, fm, eps=1e-6):
am = torch.pow(torch.abs(fm), self.p) am = torch.pow(torch.abs(fm), self.p)
am = torch.sum(am, dim=1, keepdim=True) am = torch.sum(am, dim=1, keepdim=True)
#norm = torch.norm(am, keepdim=True)#, dim=(2,3) norm = torch.norm(am, keepdim=True)#, dim=(2,3)
#am = torch.div(am, norm+eps) am = torch.div(am, norm+eps)
return am return am
\ No newline at end of file
...@@ -15,7 +15,6 @@ checkpoint = utility.checkpoint(args) ...@@ -15,7 +15,6 @@ checkpoint = utility.checkpoint(args)
check = utility.checkpoint(args) check = utility.checkpoint(args)
teacher_model = model.Model(args, check) teacher_model = model.Model(args, check)
#teacher_model.load_state_dict(torch.load('/home/iyj0121/AT_EDSR/model_best.pt'), strict=False)
teacher_model.load(apath='/home/iyj0121/AT_EDSR/') teacher_model.load(apath='/home/iyj0121/AT_EDSR/')
teacher_model.eval() teacher_model.eval()
......
...@@ -8,8 +8,6 @@ import torch ...@@ -8,8 +8,6 @@ import torch
import torch.nn.utils as utils import torch.nn.utils as utils
from tqdm import tqdm from tqdm import tqdm
from loss import at
class Trainer(): class Trainer():
def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss): def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss):
self.args = args self.args = args
...@@ -53,7 +51,7 @@ class Trainer(): ...@@ -53,7 +51,7 @@ class Trainer():
with torch.no_grad(): with torch.no_grad():
t_res, _ = self.t_model(lr, 0) t_res, _ = self.t_model(lr, 0)
kd_loss = self.KD_loss(res, t_res) kd_loss = self.KD_loss(res, t_res)
loss = self.loss(sr, hr) + 0.0001*kd_loss loss = self.loss(sr, hr) + 0.1*kd_loss
loss.backward() loss.backward()
if self.args.gclip > 0: if self.args.gclip > 0:
utils.clip_grad_value_( utils.clip_grad_value_(
......
...@@ -92,9 +92,9 @@ class checkpoint(): ...@@ -92,9 +92,9 @@ class checkpoint():
def save(self, trainer, epoch, is_best=False): def save(self, trainer, epoch, is_best=False):
trainer.model.save(self.get_path('model'), epoch, is_best=is_best) trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
trainer.loss.save(self.dir) trainer.loss.save(self.dir)
#trainer.loss.plot_loss(self.dir, epoch) trainer.loss.plot_loss(self.dir, epoch)
#self.plot_psnr(epoch) self.plot_psnr(epoch)
trainer.optimizer.save(self.dir) trainer.optimizer.save(self.dir)
torch.save(self.log, self.get_path('psnr_log.pt')) torch.save(self.log, self.get_path('psnr_log.pt'))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment