From 45ff7def11c66b5a6c663809a4b0ee3c0dbf36fb Mon Sep 17 00:00:00 2001
From: tabetomo <ikai.tomohiro@ktj.biglobe.ne.jp>
Date: Tue, 5 Mar 2019 12:17:18 +0900
Subject: [PATCH] fix multiprocessing issue in windows environment (#118)

---
 src/main.py | 32 ++++++++++++++++++--------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/main.py b/src/main.py
index a6093f7..fadbe65 100644
--- a/src/main.py
+++ b/src/main.py
@@ -10,20 +10,24 @@ from trainer import Trainer
 torch.manual_seed(args.seed)
 checkpoint = utility.checkpoint(args)
 
-if args.data_test == 'video':
-    from videotester import VideoTester
-    model = model.Model(args, checkpoint)
-    t = VideoTester(args, model, checkpoint)
-    t.test()
-else:
-    if checkpoint.ok:
-        loader = data.Data(args)
+def main():
+    global model
+    if args.data_test == 'video':
+        from videotester import VideoTester
         model = model.Model(args, checkpoint)
-        loss = loss.Loss(args, checkpoint) if not args.test_only else None
-        t = Trainer(args, loader, model, loss, checkpoint)
-        while not t.terminate():
-            t.train()
-            t.test()
+        t = VideoTester(args, model, checkpoint)
+        t.test()
+    else:
+        if checkpoint.ok:
+            loader = data.Data(args)
+            model = model.Model(args, checkpoint)
+            loss = loss.Loss(args, checkpoint) if not args.test_only else None
+            t = Trainer(args, loader, model, loss, checkpoint)
+            while not t.terminate():
+                t.train()
+                t.test()
 
-        checkpoint.done()
+            checkpoint.done()
 
+if __name__ == '__main__':
+    main()
-- 
GitLab