diff --git a/README.md b/README.md
index 4d9a413786d67d7d24e6b2cee88baf8d8719706f..1e6c5d31a8d925fd8475e29a6aabf333af48dd1e 100755
--- a/README.md
+++ b/README.md
@@ -36,10 +36,9 @@ We provide scripts for reproducing all the results from our paper. You can train
 
 **Recent updates**
 
-* July 22, 2018
-  * Thanks for recent commits that contains RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
-  * Now the dataloader is much stable than the previous version. Please erase ``DIV2K/bin`` folder that is created before this commit. Also, please avoid to use ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough spaces(~10GB) in your disk, we recommend ``--ext img``(But SLOW!).
-
+* Oct 18, 2018
+  * With ``--pre_train download``, pretrained models will be downloaded automatically from the server.
+  * Supports video input/output (inference only). Try it with ``--data_test video --dir_demo [video file directory]``.
 ## Code
 
 Clone this repository into any place you want.
@@ -167,3 +166,8 @@ sh demo.sh
 * Compatible with PyTorch 0.4.0
 * Please use the legacy/0.3.1 branch if you are using the old version of PyTorch.
 * Minor bug fixes
+
+* July 22, 2018
+  * Thanks for the recent commits that contain RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
+  * The dataloader is now much more stable than the previous version. Please erase the ``DIV2K/bin`` folder created before this commit, and avoid the ``--ext bin`` argument. Our code will automatically pre-decode PNG images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but it is SLOW!).
+
diff --git a/experiment/model/.gitignore b/experiment/model/.gitignore
deleted file mode 100644
index fc5177b5dc4c273b2ac4e2677911334f3509426e..0000000000000000000000000000000000000000
--- a/experiment/model/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-*
-!.gitignore
-!*.pt
diff --git a/experiment/model/MDSR_baseline.pt b/experiment/model/MDSR_baseline.pt
deleted file mode 100644
index 307c374aa623a545769faeb5a2fa3092b4f268f5..0000000000000000000000000000000000000000
Binary files a/experiment/model/MDSR_baseline.pt and /dev/null differ
diff --git a/experiment/model/MDSR_baseline_jpeg.pt b/experiment/model/MDSR_baseline_jpeg.pt
deleted file mode 100644
index 94e6a6ac12d52b191df6f6e9354917edf6e7c984..0000000000000000000000000000000000000000
Binary files a/experiment/model/MDSR_baseline_jpeg.pt and /dev/null differ
diff --git a/src/data/__init__.py b/src/data/__init__.py
index 26b43f2992d89b3605b2c2663159febeebc35555..827320054c221b7f7e4b25baec2e2a57259807f0 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -6,6 +6,7 @@ from torch.utils.data import ConcatDataset
 class MyConcatDataset(ConcatDataset):
     def __init__(self, datasets):
         super(MyConcatDataset, self).__init__(datasets)
+        self.train = datasets[0].train
 
     def set_scale(self, idx_scale):
         for d in self.datasets:
diff --git a/src/model/__init__.py b/src/model/__init__.py
index a2cc30d63fd417b965be92135e97cdfe7ee6cda8..ee35f5c932f28ff3bd1a4822ffc3054d5340965d 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -75,7 +75,7 @@ class Model(nn.Module):
         for s in save_dirs:
             torch.save(target.state_dict(), s)
 
-    def load(self, apath, pre_train='.', resume=-1, cpu=False):
+    def load(self, apath, pre_train='', resume=-1, cpu=False):
         if cpu:
             kwargs = {'map_location': lambda storage, loc: storage}
         else:
@@ -97,7 +97,7 @@ class Model(nn.Module):
                     model_dir=dir_model,
                     **kwargs
                 )
-            elif pre_train != '':
+            elif pre_train:
                 print('Load the model from {}'.format(pre_train))
                 load_from = torch.load(pre_train, **kwargs)
             else:
diff --git a/src/trainer.py b/src/trainer.py
index a05fa0a515b2c309e0a13a9606db54c5e9fd7527..fb73373d2702e7940ad997a3866ff7acadd1064c 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -97,9 +97,11 @@ class Trainer():
                     self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                         sr, hr, scale, self.args.rgb_range, dataset=d
                     )
-                    if self.args.save_gt: save_list.extend([lr, hr])
+                    if self.args.save_gt:
+                        save_list.extend([lr, hr])
 
-                    self.ckp.save_results(d, filename[0], save_list, scale)
+                    if self.args.save_results:
+                        self.ckp.save_results(d, filename[0], save_list, scale)
 
                 self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                 best = self.ckp.log.max(0)
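For readers who want the intended behaviour rather than the hunks: below is a minimal sketch, not the repository's code, of the checkpoint-loading dispatch that the ``src/model/__init__.py`` change settles on. An empty ``pre_train`` (the new default) means training from scratch, the special value ``download`` fetches weights from a model URL (the README's ``--pre_train download``), and any other non-empty string is treated as a local checkpoint path. The helper name ``resolve_pretrained`` and its arguments are illustrative assumptions, not identifiers from this repository.

```python
# Minimal sketch only -- `resolve_pretrained`, `model_url`, and `dir_model` are
# illustrative names, not identifiers from this repository.
import os

import torch
from torch.utils import model_zoo


def resolve_pretrained(model_url, pre_train='', dir_model='../models', cpu=False):
    # Keep tensors on CPU when requested, mirroring the `cpu` flag in Model.load.
    kwargs = {'map_location': lambda storage, loc: storage} if cpu else {}

    if pre_train == 'download':
        # The ``--pre_train download`` case: fetch the weights from the model URL.
        os.makedirs(dir_model, exist_ok=True)
        return model_zoo.load_url(model_url, model_dir=dir_model, **kwargs)
    elif pre_train:
        # Truthiness check replaces the old `pre_train != ''` comparison: any
        # non-empty string is a local checkpoint path.
        return torch.load(pre_train, **kwargs)

    # pre_train == '' (the new default): no pretrained weights, train from scratch.
    return None
```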