Skip to content
Snippets Groups Projects
Unverified Commit 8dba5581 authored by Assaad Mrad's avatar Assaad Mrad Committed by GitHub
Browse files

Allow for GPU usage on M1 mac (#347)

* feat: GPU usage on M1 mac

* Restore demo.sh
parent 9d3bb0ec
No related branches found
No related tags found
1 merge request!1Jan 09, 2018 updates
.DS_Store 0 → 100644
File added
......@@ -64,3 +64,7 @@ target/
*.txt
*.swp
.vscode
# Datasets
data/
test/
\ No newline at end of file
......@@ -18,7 +18,16 @@ class Model(nn.Module):
self.chop = args.chop
self.precision = args.precision
self.cpu = args.cpu
self.device = torch.device('cpu' if args.cpu else 'cuda')
if self.cpu:
self.device = torch.device('cpu')
else:
if torch.backends.mps.is_available():
self.device = torch.device('mps')
elif torch.cuda.is_available():
self.device = torch.device('cuda')
else:
self.device = torch.device('cpu')
self.n_GPUs = args.n_GPUs
self.save_models = args.save_models
......@@ -74,6 +83,8 @@ class Model(nn.Module):
kwargs = {}
if cpu:
kwargs = {'map_location': lambda storage, loc: storage}
else:
kwargs = {'map_location': self.device}
if resume == -1:
load_from = torch.load(
......
......@@ -129,7 +129,15 @@ class Trainer():
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
if self.args.cpu:
device = torch.device('cpu')
else:
if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
......
......@@ -41,6 +41,13 @@ class timer():
def reset(self):
self.acc = 0
def bg_target(queue):
while True:
if not queue.empty():
filename, tensor = queue.get()
if filename is None: break
imageio.imwrite(filename, tensor.numpy())
class checkpoint():
def __init__(self, args):
self.args = args
......@@ -123,16 +130,11 @@ class checkpoint():
plt.savefig(self.get_path('test_{}.pdf'.format(d)))
plt.close(fig)
def begin_background(self):
self.queue = Queue()
def bg_target(queue):
while True:
if not queue.empty():
filename, tensor = queue.get()
if filename is None: break
imageio.imwrite(filename, tensor.numpy())
self.process = [
Process(target=bg_target, args=(self.queue,)) \
for _ in range(self.n_processes)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment