diff --git a/src/main.py b/src/main.py
index a6093f70bfb25d348e3c0b1b820b92a2ef060071..fadbe657f6d7ebe195ec23089078db390eecf6fd 100644
--- a/src/main.py
+++ b/src/main.py
@@ -10,20 +10,24 @@ from trainer import Trainer
 torch.manual_seed(args.seed)
 checkpoint = utility.checkpoint(args)
 
-if args.data_test == 'video':
-    from videotester import VideoTester
-    model = model.Model(args, checkpoint)
-    t = VideoTester(args, model, checkpoint)
-    t.test()
-else:
-    if checkpoint.ok:
-        loader = data.Data(args)
+def main():
+    global model
+    if args.data_test == 'video':
+        from videotester import VideoTester
         model = model.Model(args, checkpoint)
-        loss = loss.Loss(args, checkpoint) if not args.test_only else None
-        t = Trainer(args, loader, model, loss, checkpoint)
-        while not t.terminate():
-            t.train()
-            t.test()
+        t = VideoTester(args, model, checkpoint)
+        t.test()
+    else:
+        if checkpoint.ok:
+            loader = data.Data(args)
+            model = model.Model(args, checkpoint)
+            loss = loss.Loss(args, checkpoint) if not args.test_only else None
+            t = Trainer(args, loader, model, loss, checkpoint)
+            while not t.terminate():
+                t.train()
+                t.test()
 
-        checkpoint.done()
+            checkpoint.done()
 
+if __name__ == '__main__':
+    main()