From 8dba5581a7502b92de9641eb431130d6c8ca5d7f Mon Sep 17 00:00:00 2001
From: Assaad Mrad <mradassaad2@gmail.com>
Date: Mon, 2 Jan 2023 20:47:11 -0500
Subject: [PATCH] Allow GPU usage on M1 Mac (#347)

* feat: GPU usage on M1 Mac

* Restore demo.sh
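
The existing code assumed CUDA whenever args.cpu was not set. With this
change, device selection falls back gracefully: prefer Apple's Metal
Performance Shaders (MPS) backend when available, then CUDA, then CPU.
Checkpoint loading now also passes map_location, so weights saved on a
CUDA machine can be restored onto an MPS or CPU device, and bg_target
is moved to module level so that multiprocessing can pickle it under
the "spawn" start method that macOS uses by default.

A minimal sketch of the selection logic applied in Model.__init__ and
Trainer.prepare (the pick_device helper below is illustrative only and
not part of this patch; MPS support requires PyTorch >= 1.12):

    import torch

    def pick_device(force_cpu: bool = False) -> torch.device:
        """Prefer MPS on Apple Silicon, then CUDA, then fall back to CPU."""
        if force_cpu:
            return torch.device('cpu')
        # torch.backends.mps exists only in PyTorch >= 1.12
        if torch.backends.mps.is_available():
            return torch.device('mps')
        if torch.cuda.is_available():
            return torch.device('cuda')
        return torch.device('cpu')

    # usage: device = pick_device(force_cpu=args.cpu); tensor.to(device)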
---
 .gitignore            |   4 ++++
 src/model/__init__.py |  13 ++++++++++++-
 src/trainer.py        |  10 +++++++++-
 src/utility.py        |  16 +++++++++-------
 4 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/.gitignore b/.gitignore
index fe5c524..71b4db1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -64,3 +64,7 @@ target/
 *.txt
 *.swp
 .vscode
+
+# Datasets
+data/
+test/
\ No newline at end of file
diff --git a/src/model/__init__.py b/src/model/__init__.py
index f1a1e03..2ffc49d 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -18,7 +18,16 @@ class Model(nn.Module):
         self.chop = args.chop
         self.precision = args.precision
         self.cpu = args.cpu
-        self.device = torch.device('cpu' if args.cpu else 'cuda')
+        if self.cpu:
+            self.device = torch.device('cpu')
+        else:
+            if torch.backends.mps.is_available():
+                self.device = torch.device('mps')
+            elif torch.cuda.is_available():
+                self.device = torch.device('cuda')
+            else:
+                self.device = torch.device('cpu')
+
         self.n_GPUs = args.n_GPUs
         self.save_models = args.save_models
 
@@ -74,6 +83,8 @@ class Model(nn.Module):
         kwargs = {}
         if cpu:
             kwargs = {'map_location': lambda storage, loc: storage}
+        else:
+            kwargs = {'map_location': self.device}
 
         if resume == -1:
             load_from = torch.load(
diff --git a/src/trainer.py b/src/trainer.py
index 849ae5c..1a6f8cf 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -129,7 +129,15 @@ class Trainer():
         torch.set_grad_enabled(True)
 
     def prepare(self, *args):
-        device = torch.device('cpu' if self.args.cpu else 'cuda')
+        if self.args.cpu:
+            device = torch.device('cpu')
+        else:
+            if torch.backends.mps.is_available():
+                device = torch.device('mps')
+            elif torch.cuda.is_available():
+                device = torch.device('cuda')
+            else:
+                device = torch.device('cpu')
         def _prepare(tensor):
             if self.args.precision == 'half': tensor = tensor.half()
             return tensor.to(device)
diff --git a/src/utility.py b/src/utility.py
index 7da69a7..8eb6f5e 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -41,6 +41,13 @@ class timer():
     def reset(self):
         self.acc = 0
 
+def bg_target(queue):
+    while True:
+        if not queue.empty():
+            filename, tensor = queue.get()
+            if filename is None: break
+            imageio.imwrite(filename, tensor.numpy())
+
 class checkpoint():
     def __init__(self, args):
         self.args = args
@@ -123,16 +130,11 @@ class checkpoint():
             plt.savefig(self.get_path('test_{}.pdf'.format(d)))
             plt.close(fig)
 
+    
+
     def begin_background(self):
         self.queue = Queue()
 
-        def bg_target(queue):
-            while True:
-                if not queue.empty():
-                    filename, tensor = queue.get()
-                    if filename is None: break
-                    imageio.imwrite(filename, tensor.numpy())
-        
         self.process = [
             Process(target=bg_target, args=(self.queue,)) \
             for _ in range(self.n_processes)
-- 
GitLab