diff --git a/src/main.py b/src/main.py index a6093f70bfb25d348e3c0b1b820b92a2ef060071..fadbe657f6d7ebe195ec23089078db390eecf6fd 100644 --- a/src/main.py +++ b/src/main.py @@ -10,20 +10,24 @@ from trainer import Trainer torch.manual_seed(args.seed) checkpoint = utility.checkpoint(args) -if args.data_test == 'video': - from videotester import VideoTester - model = model.Model(args, checkpoint) - t = VideoTester(args, model, checkpoint) - t.test() -else: - if checkpoint.ok: - loader = data.Data(args) +def main(): + global model + if args.data_test == 'video': + from videotester import VideoTester model = model.Model(args, checkpoint) - loss = loss.Loss(args, checkpoint) if not args.test_only else None - t = Trainer(args, loader, model, loss, checkpoint) - while not t.terminate(): - t.train() - t.test() + t = VideoTester(args, model, checkpoint) + t.test() + else: + if checkpoint.ok: + loader = data.Data(args) + model = model.Model(args, checkpoint) + loss = loss.Loss(args, checkpoint) if not args.test_only else None + t = Trainer(args, loader, model, loss, checkpoint) + while not t.terminate(): + t.train() + t.test() - checkpoint.done() + checkpoint.done() +if __name__ == '__main__': + main()