diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..70df182f09ecdc0350230e681f0b18a480ba37ea Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index fe5c52449578f8e15199f98cc3eee3611a42a166..71b4db19254639e52fa9303fe60078d9afd23fc5 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,7 @@ target/ *.txt *.swp .vscode + +# Datasets +data/ +test/ \ No newline at end of file diff --git a/src/model/__init__.py b/src/model/__init__.py index f1a1e035f3625b3dc280a7e44f7d387eb9d8ccaa..2ffc49dca6bb454a371f91d06f93f0322eb5ddd0 100644 --- a/src/model/__init__.py +++ b/src/model/__init__.py @@ -18,7 +18,16 @@ class Model(nn.Module): self.chop = args.chop self.precision = args.precision self.cpu = args.cpu - self.device = torch.device('cpu' if args.cpu else 'cuda') + if self.cpu: + self.device = torch.device('cpu') + else: + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + elif torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + self.n_GPUs = args.n_GPUs self.save_models = args.save_models @@ -74,6 +83,8 @@ class Model(nn.Module): kwargs = {} if cpu: kwargs = {'map_location': lambda storage, loc: storage} + else: + kwargs = {'map_location': self.device} if resume == -1: load_from = torch.load( diff --git a/src/trainer.py b/src/trainer.py index 849ae5c7e47aa0dd18e369c64774edc868d06e2d..1a6f8cf24bc2a4328f3d7c41c911f611950d2f6f 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -129,7 +129,15 @@ class Trainer(): torch.set_grad_enabled(True) def prepare(self, *args): - device = torch.device('cpu' if self.args.cpu else 'cuda') + if self.args.cpu: + device = torch.device('cpu') + else: + if torch.backends.mps.is_available(): + device = torch.device('mps') + elif torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') def _prepare(tensor): if self.args.precision == 'half': tensor = tensor.half() return tensor.to(device) diff --git a/src/utility.py b/src/utility.py index 7da69a701080a1a12285d6486d967befe6e6bce1..8eb6f5e07c9f4f6c292ecb98f3e2231c0f187a18 100644 --- a/src/utility.py +++ b/src/utility.py @@ -41,6 +41,13 @@ class timer(): def reset(self): self.acc = 0 +def bg_target(queue): + while True: + if not queue.empty(): + filename, tensor = queue.get() + if filename is None: break + imageio.imwrite(filename, tensor.numpy()) + class checkpoint(): def __init__(self, args): self.args = args @@ -123,16 +130,11 @@ class checkpoint(): plt.savefig(self.get_path('test_{}.pdf'.format(d))) plt.close(fig) + + def begin_background(self): self.queue = Queue() - def bg_target(queue): - while True: - if not queue.empty(): - filename, tensor = queue.get() - if filename is None: break - imageio.imwrite(filename, tensor.numpy()) - self.process = [ Process(target=bg_target, args=(self.queue,)) \ for _ in range(self.n_processes)