From 075eeec1f639315c43e5b4f85a3401c539b707fc Mon Sep 17 00:00:00 2001 From: Sanghyun Son <thstkdgus35@snu.ac.kr> Date: Fri, 19 Oct 2018 23:19:39 +0900 Subject: [PATCH] fix reported errors --- README.md | 1 + src/data/srdata.py | 5 ++++- src/main.py | 2 +- src/option.py | 2 +- src/utility.py | 2 +- 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1e6c5d3..72f97ca 100755 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ We provide scripts for reproducing all the results from our paper. You can train * **imageio** * matplotlib * tqdm +* cv2 >= 3.xx (Only if you use video input/output) **Recent updates** diff --git a/src/data/srdata.py b/src/data/srdata.py index a7c9a94..5dcf99d 100644 --- a/src/data/srdata.py +++ b/src/data/srdata.py @@ -77,7 +77,10 @@ class SRData(data.Dataset): if train: n_patches = args.batch_size * args.test_every n_images = len(args.data_train) * len(self.images_hr) - self.repeat = max(n_patches // n_images, 1) + if n_images == 0: + self.repeat = 0 + else: + self.repeat = max(n_patches // n_images, 1) # Below functions as used to prepare images def _scan(self): diff --git a/src/main.py b/src/main.py index 5ff359d..a6093f7 100644 --- a/src/main.py +++ b/src/main.py @@ -6,12 +6,12 @@ import model import loss from option import args from trainer import Trainer -from videotester import VideoTester torch.manual_seed(args.seed) checkpoint = utility.checkpoint(args) if args.data_test == 'video': + from videotester import VideoTester model = model.Model(args, checkpoint) t = VideoTester(args, model, checkpoint) t.test() diff --git a/src/option.py b/src/option.py index 6729330..0a117ec 100644 --- a/src/option.py +++ b/src/option.py @@ -114,7 +114,7 @@ parser.add_argument('--optimizer', default='ADAM', help='optimizer to use (SGD | ADAM | RMSprop)') parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') -parser.add_argument('--beta', type=tuple, default=(0.9, 0.999), +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='ADAM beta') parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability') diff --git a/src/utility.py b/src/utility.py index 25aabd9..6264df5 100644 --- a/src/utility.py +++ b/src/utility.py @@ -192,7 +192,7 @@ def make_optimizer(args, my_model): elif args.optimizer == 'ADAM': optimizer_function = optim.Adam kwargs = { - 'betas': args.beta, + 'betas': args.betas, 'eps': args.epsilon } elif args.optimizer == 'RMSprop': -- GitLab