From 45ff7def11c66b5a6c663809a4b0ee3c0dbf36fb Mon Sep 17 00:00:00 2001 From: tabetomo <ikai.tomohiro@ktj.biglobe.ne.jp> Date: Tue, 5 Mar 2019 12:17:18 +0900 Subject: [PATCH] fix multiprocessing issue in windows environment (#118) --- src/main.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/main.py b/src/main.py index a6093f7..fadbe65 100644 --- a/src/main.py +++ b/src/main.py @@ -10,20 +10,24 @@ from trainer import Trainer torch.manual_seed(args.seed) checkpoint = utility.checkpoint(args) -if args.data_test == 'video': - from videotester import VideoTester - model = model.Model(args, checkpoint) - t = VideoTester(args, model, checkpoint) - t.test() -else: - if checkpoint.ok: - loader = data.Data(args) +def main(): + global model + if args.data_test == 'video': + from videotester import VideoTester model = model.Model(args, checkpoint) - loss = loss.Loss(args, checkpoint) if not args.test_only else None - t = Trainer(args, loader, model, loss, checkpoint) - while not t.terminate(): - t.train() - t.test() + t = VideoTester(args, model, checkpoint) + t.test() + else: + if checkpoint.ok: + loader = data.Data(args) + model = model.Model(args, checkpoint) + loss = loss.Loss(args, checkpoint) if not args.test_only else None + t = Trainer(args, loader, model, loss, checkpoint) + while not t.terminate(): + t.train() + t.test() - checkpoint.done() + checkpoint.done() +if __name__ == '__main__': + main() -- GitLab