diff --git a/src/loss/at.py b/src/loss/at.py
new file mode 100644
index 0000000000000000000000000000000000000000..35965a2d7333c06c2fbd28c71968440d34272c84
--- /dev/null
+++ b/src/loss/at.py
@@ -0,0 +1,37 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+'''
+AT loss: compares spatial attention maps built from the channel-wise sum of absolute values raised to the power p
+'''
+class AT(nn.Module):
+	'''
+	Paying More Attention to Attention: Improving the Performance of Convolutional
+	Neural Networks via Attention Transfer
+	https://arxiv.org/pdf/1612.03928.pdf
+	'''
+	def __init__(self, p):
+		super(AT, self).__init__()
+		self.p = p
+
+	def forward(self, fm_s, fm_t):
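+		# MSE between the student's and the teacher's attention maps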
+		loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))
+
+		return loss
+
+	def attention_map(self, fm, eps=1e-6):
+		# Spatial attention map: channel-wise sum of |activation|^p
+		am = torch.pow(torch.abs(fm), self.p)
+		am = torch.sum(am, dim=1, keepdim=True)
+		# Optional per-map L2 normalisation from the paper, disabled here;
+		# eps guards against division by zero if it is re-enabled
+		#norm = torch.norm(am, dim=(2, 3), keepdim=True)
+		#am = torch.div(am, norm + eps)
+
+		return am
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index dbfac3e008d04cc72f438179fdce265aa1f079ad..d592f12cd3150408adab0383d1f7d5ec42475bf1 100644
--- a/src/main.py
+++ b/src/main.py
@@ -7,9 +7,16 @@ import loss
 from option import args
 from trainer import Trainer
 
+from loss import at
+
 torch.manual_seed(args.seed)
 checkpoint = utility.checkpoint(args)
 
+# Frozen teacher: a pretrained EDSR whose intermediate features supply the attention-transfer targets
+teacher_model = model.Model(args, checkpoint)
+teacher_model.load_state_dict(torch.load('/home/iyj0121/EDSR-PyTorch/experiment/EDSR_x2.pt'), strict=False)
+teacher_model.eval()
+
 def main():
     global model
     if args.data_test == ['video']:
@@ -22,7 +29,9 @@ def main():
             loader = data.Data(args)
             _model = model.Model(args, checkpoint)
             _loss = loss.Loss(args, checkpoint) if not args.test_only else None
-            t = Trainer(args, loader, _model, _loss, checkpoint)
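+            # Attention-transfer criterion; p=2 compares squared activation maps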
+            kd_loss = at.AT(p=2.0)
+            t = Trainer(args, loader, _model, _loss, checkpoint, teacher_model, kd_loss)
             while not t.terminate():
                 t.train()
                 t.test()
diff --git a/src/model/__init__.py b/src/model/__init__.py
index 2ffc49dca6bb454a371f91d06f93f0322eb5ddd0..6d4220bf49267f36fed11365013426ca32e856d0 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -21,9 +21,10 @@ class Model(nn.Module):
         if self.cpu:
             self.device = torch.device('cpu')
         else:
-            if torch.backends.mps.is_available():
-                self.device = torch.device('mps')
-            elif torch.cuda.is_available():
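+            # MPS path disabled so that CUDA is selected when available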
+            #if torch.backends.mps.is_available():
+            #    self.device = torch.device('mps')
+            if torch.cuda.is_available():
                 self.device = torch.device('cuda')
             else:
                 self.device = torch.device('cpu')
diff --git a/src/model/edsr.py b/src/model/edsr.py
index ef4ffb1e2f9c485140e92b3b06ddc93ab3877fba..74c6e61d622d44d1e48c071015a7305cfb108c4b 100644
--- a/src/model/edsr.py
+++ b/src/model/edsr.py
@@ -62,7 +62,8 @@ class EDSR(nn.Module):
         x = self.tail(res)
         x = self.add_mean(x)
 
-        return x 
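+        # Also return the pre-upsampling residual features for attention transfer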
+        return res, x 
 
     def load_state_dict(self, state_dict, strict=True):
         own_state = self.state_dict()
diff --git a/src/option.py b/src/option.py
index 8ec9634813b2a3b4799341e0b23e9bac856fc6e6..097343a3b1ee6453f6ce032aeddf279cfde2b37a 100644
--- a/src/option.py
+++ b/src/option.py
@@ -19,7 +19,7 @@ parser.add_argument('--seed', type=int, default=1,
                     help='random seed')
 
 # Data specifications
-parser.add_argument('--dir_data', type=str, default='../../../dataset',
+parser.add_argument('--dir_data', type=str, default='/home/iyj0121/EDSR-PyTorch/dataset',
                     help='dataset directory')
 parser.add_argument('--dir_demo', type=str, default='../test',
                     help='demo image directory')
@@ -87,9 +87,9 @@ parser.add_argument('--reset', action='store_true',
                     help='reset the training')
 parser.add_argument('--test_every', type=int, default=1000,
                     help='do test per every N batches')
-parser.add_argument('--epochs', type=int, default=300,
+parser.add_argument('--epochs', type=int, default=100,
                     help='number of epochs to train')
-parser.add_argument('--batch_size', type=int, default=16,
+parser.add_argument('--batch_size', type=int, default=8,
                     help='input batch size for training')
 parser.add_argument('--split_batch', type=int, default=1,
                     help='split the batch into smaller chunks')
diff --git a/src/trainer.py b/src/trainer.py
index 1a6f8cf24bc2a4328f3d7c41c911f611950d2f6f..07d31a1a6563f15fa6632233609137f325abd7e1 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -8,8 +8,8 @@ import torch
 import torch.nn.utils as utils
 from tqdm import tqdm
 
 class Trainer():
-    def __init__(self, args, loader, my_model, my_loss, ckp):
+    def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss):
         self.args = args
         self.scale = args.scale
 
@@ -18,6 +18,9 @@ class Trainer():
         self.loader_test = loader.loader_test
         self.model = my_model
         self.loss = my_loss
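+        # Attention-transfer criterion and frozen teacher passed in from main.py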
+        self.KD_loss = kd_loss
+        self.t_model = teacher_model
         self.optimizer = utility.make_optimizer(args, self.model)
 
         if self.args.load != '':
@@ -45,8 +48,13 @@
             timer_model.tic()
 
             self.optimizer.zero_grad()
-            sr = self.model(lr, 0)
-            loss = self.loss(sr, hr)
+            res, sr = self.model(lr, 0)
+            # The teacher's features are fixed targets, so skip gradient tracking
+            with torch.no_grad():
+                t_res, _ = self.t_model(lr, 0)
+            kd_loss = self.KD_loss(res, t_res)
+            # Reconstruction loss plus the attention-transfer term, weighted by 0.1
+            loss = self.loss(sr, hr) + 0.1 * kd_loss
             loss.backward()
             if self.args.gclip > 0:
                 utils.clip_grad_value_(
@@ -88,7 +96,7 @@
                 d.dataset.set_scale(idx_scale)
                 for lr, hr, filename in tqdm(d, ncols=80):
                     lr, hr = self.prepare(lr, hr)
-                    sr = self.model(lr, idx_scale)
+                    _, sr = self.model(lr, idx_scale)
                     sr = utility.quantize(sr, self.args.rgb_range)
 
                     save_list = [sr]
@@ -132,9 +140,9 @@
         if self.args.cpu:
             device = torch.device('cpu')
         else:
-            if torch.backends.mps.is_available():
-                device = torch.device('mps')
-            elif torch.cuda.is_available():
+            #if torch.backends.mps.is_available():
+            #    device = torch.device('mps')
+            if torch.cuda.is_available():
                 device = torch.device('cuda')
             else:
                 device = torch.device('cpu')
diff --git a/src/utility.py b/src/utility.py
index 8eb6f5e07c9f4f6c292ecb98f3e2231c0f187a18..0860c49fa6daba70353f437301dd6d25297b7369 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -92,9 +92,10 @@ class checkpoint():
     def save(self, trainer, epoch, is_best=False):
         trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
         trainer.loss.save(self.dir)
-        trainer.loss.plot_loss(self.dir, epoch)
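+        # Plot calls disabled; the PSNR log is still saved below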
+        #trainer.loss.plot_loss(self.dir, epoch)
 
-        self.plot_psnr(epoch)
+        #self.plot_psnr(epoch)
         trainer.optimizer.save(self.dir)
         torch.save(self.log, self.get_path('psnr_log.pt'))