From 8dba5581a7502b92de9641eb431130d6c8ca5d7f Mon Sep 17 00:00:00 2001 From: Assaad Mrad <mradassaad2@gmail.com> Date: Mon, 2 Jan 2023 20:47:11 -0500 Subject: [PATCH] Allow for GPU usage on M1 mac (#347) * feat: GPU usage on M1 mac * Restore demo.sh --- .DS_Store | Bin 0 -> 6148 bytes .gitignore | 4 ++++ src/model/__init__.py | 13 ++++++++++++- src/trainer.py | 10 +++++++++- src/utility.py | 16 +++++++++------- 5 files changed, 34 insertions(+), 9 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..70df182f09ecdc0350230e681f0b18a480ba37ea GIT binary patch literal 6148 zcmZQzU|@7AO)+F(5MW?n;9!8zEL;p&0Z1N%F(jFwB5WY@KxQPB7Z)Vu<R>vOFzkfN zjZ&i_Fd71bHv~X=mxUpPA(5ekArU$MCKcpl7MB<pTw`QnW?^Mx=V0gH=7<f>$S)5r zNh~QXc1kRY2Ju4j^K+75?8Kz7%+&ID0TJi?ypqJsywoC)lHkmg)TG3snDETJl>Bn1 z{L;LXVz6GQ1P3PvXS{%9b+wU&p^k!yv0<%_LbaiRrIC(;sgXf#EhmSlvc7dte0EN5 zUVb+uEEyRgGy^Y`hEY8X3=D7&l?4~&<>cq3LkcJcE{0Tw3WfrPRE8pkOom*BRE9i; zlHuZMte!&)02YQ4hE#^);p1tzhtNC<OAVtmo)Ca$Ar5XlYDT3-Ltr!nMrH_r$_E8# z)#(7G8z3}Dih+@V0o(;(1l73EAYugd10ZTZT0v?+T0t~OD+42l1(pYEWng4r0qbN0 zcS9H$7{Of<5DnJOz{mjB&cMh3*3JN{_ZcDD85kkj85p5G6h=^w24p@&I|Cy`JJ^m< zVl)IsLjVy1%n+sksQ!0lV8GS?ho~AQM?+vV1cqe@FtWG=yEuU=O&s0>)wQ7dGyy6P zs{KLLF(as+hUfz+0n0K$1{5X09Eco9E2!EBSH+ABkeYn7Api@ZQF=54=pO<Aq$0=s literal 0 HcmV?d00001 diff --git a/.gitignore b/.gitignore index fe5c524..71b4db1 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,7 @@ target/ *.txt *.swp .vscode + +# Datasets +data/ +test/ \ No newline at end of file diff --git a/src/model/__init__.py b/src/model/__init__.py index f1a1e03..2ffc49d 100644 --- a/src/model/__init__.py +++ b/src/model/__init__.py @@ -18,7 +18,16 @@ class Model(nn.Module): self.chop = args.chop self.precision = args.precision self.cpu = args.cpu - self.device = torch.device('cpu' if args.cpu else 'cuda') + if self.cpu: + self.device = torch.device('cpu') + else: + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + elif torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + self.n_GPUs = args.n_GPUs self.save_models = args.save_models @@ -74,6 +83,8 @@ class Model(nn.Module): kwargs = {} if cpu: kwargs = {'map_location': lambda storage, loc: storage} + else: + kwargs = {'map_location': self.device} if resume == -1: load_from = torch.load( diff --git a/src/trainer.py b/src/trainer.py index 849ae5c..1a6f8cf 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -129,7 +129,15 @@ class Trainer(): torch.set_grad_enabled(True) def prepare(self, *args): - device = torch.device('cpu' if self.args.cpu else 'cuda') + if self.args.cpu: + device = torch.device('cpu') + else: + if torch.backends.mps.is_available(): + device = torch.device('mps') + elif torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') def _prepare(tensor): if self.args.precision == 'half': tensor = tensor.half() return tensor.to(device) diff --git a/src/utility.py b/src/utility.py index 7da69a7..8eb6f5e 100644 --- a/src/utility.py +++ b/src/utility.py @@ -41,6 +41,13 @@ class timer(): def reset(self): self.acc = 0 +def bg_target(queue): + while True: + if not queue.empty(): + filename, tensor = queue.get() + if filename is None: break + imageio.imwrite(filename, tensor.numpy()) + class checkpoint(): def __init__(self, args): self.args = args @@ -123,16 +130,11 @@ class checkpoint(): plt.savefig(self.get_path('test_{}.pdf'.format(d))) plt.close(fig) + + def begin_background(self): self.queue = Queue() - def bg_target(queue): - while True: - if not queue.empty(): - filename, tensor = queue.get() - if filename is None: break - imageio.imwrite(filename, tensor.numpy()) - self.process = [ Process(target=bg_target, args=(self.queue,)) \ for _ in range(self.n_processes) -- GitLab