Skip to content
Snippets Groups Projects
Commit 075eeec1 authored by Sanghyun Son's avatar Sanghyun Son
Browse files

fix reported errors

parent 70294bb6
Branches
No related tags found
1 merge request!1Jan 09, 2018 updates
...@@ -33,6 +33,7 @@ We provide scripts for reproducing all the results from our paper. You can train ...@@ -33,6 +33,7 @@ We provide scripts for reproducing all the results from our paper. You can train
* **imageio** * **imageio**
* matplotlib * matplotlib
* tqdm * tqdm
* cv2 >= 3.xx (Only if you use video input/output)
**Recent updates** **Recent updates**
......
...@@ -77,6 +77,9 @@ class SRData(data.Dataset): ...@@ -77,6 +77,9 @@ class SRData(data.Dataset):
if train: if train:
n_patches = args.batch_size * args.test_every n_patches = args.batch_size * args.test_every
n_images = len(args.data_train) * len(self.images_hr) n_images = len(args.data_train) * len(self.images_hr)
if n_images == 0:
self.repeat = 0
else:
self.repeat = max(n_patches // n_images, 1) self.repeat = max(n_patches // n_images, 1)
# Below functions as used to prepare images # Below functions as used to prepare images
......
...@@ -6,12 +6,12 @@ import model ...@@ -6,12 +6,12 @@ import model
import loss import loss
from option import args from option import args
from trainer import Trainer from trainer import Trainer
from videotester import VideoTester
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args) checkpoint = utility.checkpoint(args)
if args.data_test == 'video': if args.data_test == 'video':
from videotester import VideoTester
model = model.Model(args, checkpoint) model = model.Model(args, checkpoint)
t = VideoTester(args, model, checkpoint) t = VideoTester(args, model, checkpoint)
t.test() t.test()
......
...@@ -114,7 +114,7 @@ parser.add_argument('--optimizer', default='ADAM', ...@@ -114,7 +114,7 @@ parser.add_argument('--optimizer', default='ADAM',
help='optimizer to use (SGD | ADAM | RMSprop)') help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9, parser.add_argument('--momentum', type=float, default=0.9,
help='SGD momentum') help='SGD momentum')
parser.add_argument('--beta', type=tuple, default=(0.9, 0.999), parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
help='ADAM beta') help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8, parser.add_argument('--epsilon', type=float, default=1e-8,
help='ADAM epsilon for numerical stability') help='ADAM epsilon for numerical stability')
......
...@@ -192,7 +192,7 @@ def make_optimizer(args, my_model): ...@@ -192,7 +192,7 @@ def make_optimizer(args, my_model):
elif args.optimizer == 'ADAM': elif args.optimizer == 'ADAM':
optimizer_function = optim.Adam optimizer_function = optim.Adam
kwargs = { kwargs = {
'betas': args.beta, 'betas': args.betas,
'eps': args.epsilon 'eps': args.epsilon
} }
elif args.optimizer == 'RMSprop': elif args.optimizer == 'RMSprop':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment