Skip to content
Snippets Groups Projects
main.py 985 B
Newer Older
  • Learn to ignore specific revisions
  • 영제 임's avatar
    영제 임 committed
    import torch
    
    import utility
    import data
    import model
    import loss
    from option import args
    from trainer import Trainer
    from model.sspcab_torch import MY_SSPCAB
    
    # Seed torch's random number generator from the configured seed.
    torch.manual_seed(args.seed)
    # Module-level checkpoint object built from the parsed options; main()
    # reads its `ok` flag, hands it to the model/loss/trainer constructors,
    # and calls `done()` once the train/test loop has finished.
    checkpoint = utility.checkpoint(args)
    
    def main():
        """Run either the video test or the train/test loop, chosen by ``args``.

        When ``args.data_test == ['video']`` a ``VideoTester`` is built around
        the model and its ``test()`` is called once.

        Otherwise, if ``checkpoint.ok`` is truthy, a ``Trainer`` is built and
        ``train()`` / ``test()`` alternate until ``terminate()`` returns a
        truthy value, after which ``checkpoint.done()`` is called. If
        ``checkpoint.ok`` is falsy, nothing is done in this mode.
        """
        if args.data_test == ['video']:
            # Imported here, so `videotester` is only loaded in video mode.
            from videotester import VideoTester

            # Keep the instance in a local name instead of `global model`, so
            # the imported `model` module is not replaced by the instance.
            _model = model.Model(args, checkpoint)
            tester = VideoTester(args, _model, checkpoint)
            tester.test()
            return

        if not checkpoint.ok:
            return

        loader = data.Data(args)
        _model = model.Model(args, checkpoint)
        # No loss object is built for test-only runs.
        _loss = None if args.test_only else loss.Loss(args, checkpoint)

        # Prefer CUDA (the previous hard-coded target) but fall back to CPU
        # when no CUDA device is available instead of raising.
        # NOTE(review): the main model's device is chosen elsewhere (possibly
        # from an option such as a "cpu" flag) -- confirm both end up on the
        # same device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        sub_model = MY_SSPCAB(channels=3, reduction_ratio=8)
        sub_model.to(device)

        trainer = Trainer(args, loader, _model, _loss, checkpoint, sub_model)
        while not trainer.terminate():
            trainer.train()
            trainer.test()

        checkpoint.done()
    
    # Run main() only when this file is executed as a script, not on import.
    if __name__ == '__main__':
        main()