Skip to content

Commit 1aa2aa5

Browse files
committed
test without env variable
1 parent 39dc87f commit 1aa2aa5

File tree

3 files changed

+26
-20
lines changed

3 files changed

+26
-20
lines changed

deepecho/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
__email__ = '[email protected]'
55
__version__ = '0.7.1.dev0'
66
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
7+
import os
8+
9+
# Enable fallback so ops not implemented on MPS run on CPU
10+
# https://github.com/pytorch/pytorch/issues/77764
11+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
712

813
from deepecho.demo import load_demo
914
from deepecho.models.basic_gan import BasicGANModel

deepecho/models/_utils.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
import sys
1+
import os
2+
3+
# Enable fallback so ops not implemented on MPS run on CPU
4+
# https://github.com/pytorch/pytorch/issues/77764
5+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
6+
7+
import platform
28
import warnings
39

410
import torch
@@ -32,15 +38,16 @@ def _validate_gpu_parameters(enable_gpu, cuda):
3238
def _set_device(enable_gpu):
3339
"""Set the torch device based on the `enable_gpu` parameter and system capabilities."""
3440
if enable_gpu:
35-
if sys.platform == 'darwin': # macOS
36-
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
37-
device = 'mps'
38-
else:
39-
device = 'cpu'
40-
else: # Linux/Windows
41-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
42-
else:
43-
device = 'cpu'
41+
if (
42+
platform.machine() == 'arm64'
43+
and getattr(torch.backends, 'mps', None)
44+
and torch.backends.mps.is_available()
45+
):
46+
device = 'mps'
47+
else:
48+
device = 'cpu'
49+
else: # Linux/Windows
50+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
4451

4552
return torch.device(device)
4653

deepecho/models/par.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from tqdm import tqdm
99

10+
from deepecho.models._utils import validate_and_set_device
1011
from deepecho.models.base import DeepEcho
1112

1213
LOGGER = logging.getLogger(__name__)
@@ -98,18 +99,11 @@ class PARModel(DeepEcho):
9899
Whether to print progress to console or not.
99100
"""
100101

101-
def __init__(self, epochs=128, sample_size=1, cuda=True, verbose=True):
102+
def __init__(self, epochs=128, sample_size=1, enable_gpu=True, verbose=True, cuda=None):
102103
self.epochs = epochs
103104
self.sample_size = sample_size
104-
105-
if not cuda or not torch.cuda.is_available():
106-
device = 'cpu'
107-
elif isinstance(cuda, str):
108-
device = cuda
109-
else:
110-
device = 'cuda'
111-
112-
self.device = torch.device(device)
105+
self.device = validate_and_set_device(enable_gpu=enable_gpu, cuda=cuda)
106+
self._enable_gpu = cuda if cuda is not None else enable_gpu
113107
self.verbose = verbose
114108
self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])
115109

0 commit comments

Comments
 (0)