diff --git a/src/loss/at.py b/src/loss/at.py
new file mode 100644
index 0000000000000000000000000000000000000000..35965a2d7333c06c2fbd28c71968440d34272c84
--- /dev/null
+++ b/src/loss/at.py
@@ -0,0 +1,33 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+'''
+AT with sum of absolute values with power p
+'''
+class AT(nn.Module):
+    '''
+    Paying More Attention to Attention: Improving the Performance of Convolutional
+    Neural Networks via Attention Transfer
+    https://arxiv.org/pdf/1612.03928.pdf
+    '''
+    def __init__(self, p):
+        super(AT, self).__init__()
+        self.p = p
+
+    def forward(self, fm_s, fm_t):
+        loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))
+
+        return loss
+
+    def attention_map(self, fm, eps=1e-6):
+        am = torch.pow(torch.abs(fm), self.p)
+        am = torch.sum(am, dim=1, keepdim=True)
+        #norm = torch.norm(am, keepdim=True)#, dim=(2,3)
+        #am = torch.div(am, norm+eps)
+
+        return am
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index dbfac3e008d04cc72f438179fdce265aa1f079ad..d592f12cd3150408adab0383d1f7d5ec42475bf1 100644
--- a/src/main.py
+++ b/src/main.py
@@ -7,9 +7,17 @@ import loss
 from option import args
 from trainer import Trainer
 
+from loss import at
+
 torch.manual_seed(args.seed)
 checkpoint = utility.checkpoint(args)
 
+check = utility.checkpoint(args)
+
+teacher_model = model.Model(args, check)
+teacher_model.load_state_dict(torch.load('/home/iyj0121/EDSR-PyTorch/experiment/EDSR_x2.pt'), strict=False)
+teacher_model.eval()
+
 def main():
     global model
     if args.data_test == ['video']:
@@ -22,7 +30,8 @@ def main():
             loader = data.Data(args)
             _model = model.Model(args, checkpoint)
             _loss = loss.Loss(args, checkpoint) if not args.test_only else None
-            t = Trainer(args, loader, _model, _loss, checkpoint)
+            kd_loss = at.AT(p=2.0)
+            t = Trainer(args, loader, _model, _loss, checkpoint, teacher_model, kd_loss)
             while not t.terminate():
                 t.train()
                 t.test()
diff --git a/src/model/__init__.py b/src/model/__init__.py
index 2ffc49dca6bb454a371f91d06f93f0322eb5ddd0..6d4220bf49267f36fed11365013426ca32e856d0 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -21,9 +21,9 @@ class Model(nn.Module):
         if self.cpu:
             self.device = torch.device('cpu')
         else:
-            if torch.backends.mps.is_available():
-                self.device = torch.device('mps')
-            elif torch.cuda.is_available():
+            #if torch.backends.mps.is_available():
+            #    self.device = torch.device('mps')
+            if torch.cuda.is_available():
                 self.device = torch.device('cuda')
             else:
                 self.device = torch.device('cpu')
diff --git a/src/model/edsr.py b/src/model/edsr.py
index ef4ffb1e2f9c485140e92b3b06ddc93ab3877fba..74c6e61d622d44d1e48c071015a7305cfb108c4b 100644
--- a/src/model/edsr.py
+++ b/src/model/edsr.py
@@ -62,7 +62,7 @@ class EDSR(nn.Module):
         x = self.tail(res)
         x = self.add_mean(x)
 
-        return x
+        return res, x
 
     def load_state_dict(self, state_dict, strict=True):
         own_state = self.state_dict()
diff --git a/src/option.py b/src/option.py
index 8ec9634813b2a3b4799341e0b23e9bac856fc6e6..097343a3b1ee6453f6ce032aeddf279cfde2b37a 100644
--- a/src/option.py
+++ b/src/option.py
@@ -19,7 +19,7 @@ parser.add_argument('--seed', type=int, default=1,
                     help='random seed')
 
 # Data specifications
-parser.add_argument('--dir_data', type=str, default='../../../dataset',
+parser.add_argument('--dir_data', type=str, default='/home/iyj0121/EDSR-PyTorch/dataset',
                     help='dataset directory')
 parser.add_argument('--dir_demo', type=str, default='../test',
                     help='demo image directory')
@@ -87,9 +87,9 @@ parser.add_argument('--reset', action='store_true',
                     help='reset the training')
 parser.add_argument('--test_every', type=int, default=1000,
                     help='do test per every N batches')
-parser.add_argument('--epochs', type=int, default=300,
+parser.add_argument('--epochs', type=int, default=100,
                     help='number of epochs to train')
-parser.add_argument('--batch_size', type=int, default=16,
+parser.add_argument('--batch_size', type=int, default=8,
                     help='input batch size for training')
 parser.add_argument('--split_batch', type=int, default=1,
                     help='split the batch into smaller chunks')
diff --git a/src/trainer.py b/src/trainer.py
index 1a6f8cf24bc2a4328f3d7c41c911f611950d2f6f..07d31a1a6563f15fa6632233609137f325abd7e1 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -8,8 +8,10 @@ import torch
 import torch.nn.utils as utils
 from tqdm import tqdm
 
+from loss import at
+
 class Trainer():
-    def __init__(self, args, loader, my_model, my_loss, ckp):
+    def __init__(self, args, loader, my_model, my_loss, ckp, teacher_model, kd_loss):
         self.args = args
         self.scale = args.scale
 
@@ -18,6 +20,8 @@ class Trainer():
         self.loader_test = loader.loader_test
         self.model = my_model
         self.loss = my_loss
+        self.KD_loss = kd_loss
+        self.t_model = teacher_model
         self.optimizer = utility.make_optimizer(args, self.model)
 
         if self.args.load != '':
@@ -45,8 +49,10 @@
             timer_model.tic()
 
             self.optimizer.zero_grad()
-            sr = self.model(lr, 0)
-            loss = self.loss(sr, hr)
+            res, sr = self.model(lr, 0)
+            t_res, _ = self.t_model(lr, 0)
+            kd_loss = self.KD_loss(res, t_res)
+            loss = self.loss(sr, hr) + 0.1*kd_loss
             loss.backward()
             if self.args.gclip > 0:
                 utils.clip_grad_value_(
@@ -88,7 +94,7 @@
                 d.dataset.set_scale(idx_scale)
                 for lr, hr, filename in tqdm(d, ncols=80):
                     lr, hr = self.prepare(lr, hr)
-                    sr = self.model(lr, idx_scale)
+                    _, sr = self.model(lr, idx_scale)
                     sr = utility.quantize(sr, self.args.rgb_range)
 
                     save_list = [sr]
@@ -132,9 +138,9 @@
         if self.args.cpu:
             device = torch.device('cpu')
         else:
-            if torch.backends.mps.is_available():
-                device = torch.device('mps')
-            elif torch.cuda.is_available():
+            #if torch.backends.mps.is_available():
+            #    device = torch.device('mps')
+            if torch.cuda.is_available():
                 device = torch.device('cuda')
             else:
                 device = torch.device('cpu')
diff --git a/src/utility.py b/src/utility.py
index 8eb6f5e07c9f4f6c292ecb98f3e2231c0f187a18..0860c49fa6daba70353f437301dd6d25297b7369 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -92,9 +92,9 @@ class checkpoint():
     def save(self, trainer, epoch, is_best=False):
         trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
         trainer.loss.save(self.dir)
-        trainer.loss.plot_loss(self.dir, epoch)
+        #trainer.loss.plot_loss(self.dir, epoch)
 
-        self.plot_psnr(epoch)
+        #self.plot_psnr(epoch)
 
         trainer.optimizer.save(self.dir)
         torch.save(self.log, self.get_path('psnr_log.pt'))
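
Note on the new AT loss: src/loss/at.py reduces a feature map to a single-channel spatial attention map by taking |F|^p and summing over the channel dimension (the commented-out lines would additionally normalize each map), then compares the student and teacher maps with an MSE. Below is a minimal, self-contained sketch of that computation on dummy tensors; the shapes, channel counts, and variable names are illustrative assumptions, not values taken from the repository.

import torch
import torch.nn.functional as F

def attention_map(fm, p=2.0):
    # |F|^p summed over channels -> (N, 1, H, W), mirroring AT.attention_map
    # (the norm division is commented out in the diff, so it is omitted here too).
    return torch.abs(fm).pow(p).sum(dim=1, keepdim=True)

def at_loss(fm_s, fm_t, p=2.0):
    # MSE between student and teacher attention maps, as in AT.forward.
    return F.mse_loss(attention_map(fm_s, p), attention_map(fm_t, p))

# Dummy feature maps: the channel counts may differ because the attention map
# collapses the channel dimension; only the spatial sizes must match.
fm_s = torch.randn(8, 64, 48, 48)    # student body features (assumed shape)
fm_t = torch.randn(8, 256, 48, 48)   # teacher body features (assumed shape)
print(at_loss(fm_s, fm_t).item())

Because both maps end up with a single channel, the student and teacher backbones would not need the same width; main.py instantiates this loss with p=2.0.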
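
On the trainer side, the modified EDSR forward returns its pre-tail residual features together with the SR image, the teacher runs on the same LR batch, and the AT loss between the two feature maps is added to the pixel loss with a fixed weight of 0.1. The sketch below restates that step with hypothetical stand-in names (training_step, student, teacher, pixel_loss are not from the repository); wrapping the teacher forward in torch.no_grad() is an assumption for clarity and memory savings, not something the diff itself does.

import torch

def training_step(student, teacher, pixel_loss, at_loss, optimizer,
                  lr, hr, kd_weight=0.1):
    # One distillation step, mirroring Trainer.train() in the diff:
    # total loss = reconstruction loss + 0.1 * attention-transfer loss.
    optimizer.zero_grad()
    res, sr = student(lr, 0)       # (features, SR output) after the edsr.py change
    with torch.no_grad():          # assumption: teacher acts only as a fixed target
        t_res, _ = teacher(lr, 0)
    loss = pixel_loss(sr, hr) + kd_weight * at_loss(res, t_res)
    loss.backward()
    optimizer.step()
    return loss.item()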