diff --git a/src/trainer.py b/src/trainer.py index 07d31a1a6563f15fa6632233609137f325abd7e1..9c2e1531072554d2ec27d5c96c98f27f6f2378da 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -52,7 +52,7 @@ class Trainer(): res, sr = self.model(lr, 0) t_res, _ = self.t_model(lr, 0) kd_loss = self.KD_loss(res, t_res) - loss = self.loss(sr, hr) + 0.1*kd_loss + loss = self.loss(sr, hr) + 0.01*kd_loss loss.backward() if self.args.gclip > 0: utils.clip_grad_value_(