diff --git a/README.md b/README.md index 1e6c5d31a8d925fd8475e29a6aabf333af48dd1e..72f97ca27ce3c6def57937ed7002a260d2de00cd 100755 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ We provide scripts for reproducing all the results from our paper. You can train * **imageio** * matplotlib * tqdm +* cv2 >= 3.xx (Only if you use video input/output) **Recent updates** diff --git a/src/data/srdata.py b/src/data/srdata.py index a7c9a947fdf492f640fc329980ff43c4efcdb8c0..5dcf99d2932d2d987e54663bba9479645ee2f924 100644 --- a/src/data/srdata.py +++ b/src/data/srdata.py @@ -77,7 +77,10 @@ class SRData(data.Dataset): if train: n_patches = args.batch_size * args.test_every n_images = len(args.data_train) * len(self.images_hr) - self.repeat = max(n_patches // n_images, 1) + if n_images == 0: + self.repeat = 0 + else: + self.repeat = max(n_patches // n_images, 1) # Below functions as used to prepare images def _scan(self): diff --git a/src/main.py b/src/main.py index 5ff359de1111c25021df1acf2205add56eac56b4..a6093f70bfb25d348e3c0b1b820b92a2ef060071 100644 --- a/src/main.py +++ b/src/main.py @@ -6,12 +6,12 @@ import model import loss from option import args from trainer import Trainer -from videotester import VideoTester torch.manual_seed(args.seed) checkpoint = utility.checkpoint(args) if args.data_test == 'video': + from videotester import VideoTester model = model.Model(args, checkpoint) t = VideoTester(args, model, checkpoint) t.test() diff --git a/src/option.py b/src/option.py index 672933096c14ddb4bf3b05c14382f0d4cdd206e5..0a117ec65e73ef642ae6b956525d1c704339c996 100644 --- a/src/option.py +++ b/src/option.py @@ -114,7 +114,7 @@ parser.add_argument('--optimizer', default='ADAM', help='optimizer to use (SGD | ADAM | RMSprop)') parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') -parser.add_argument('--beta', type=tuple, default=(0.9, 0.999), +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='ADAM beta') parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability') diff --git a/src/utility.py b/src/utility.py index 25aabd9f554fce1cbd6d830d8985ed9c6dad5276..6264df57113943405e92f448a63f21b29542c3f4 100644 --- a/src/utility.py +++ b/src/utility.py @@ -192,7 +192,7 @@ def make_optimizer(args, my_model): elif args.optimizer == 'ADAM': optimizer_function = optim.Adam kwargs = { - 'betas': args.beta, + 'betas': args.betas, 'eps': args.epsilon } elif args.optimizer == 'RMSprop':