Skip to content

Commit 19be578

Browse files
committed
define _set_device()
1 parent 48b0a2f commit 19be578

File tree

2 files changed

+69
-24
lines changed

2 files changed

+69
-24
lines changed

deepecho/models/basic_gan.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,34 @@
1414
LOGGER = logging.getLogger(__name__)
1515

1616

17+
def _set_device(enable_gpu, cuda):
18+
if cuda is not None:
19+
if not enable_gpu:
20+
raise ValueError(
21+
'Cannot set `cuda` and `enable_gpu` together. Please use only `enable_gpu`.'
22+
)
23+
24+
warnings.warn(
25+
'`cuda` parameter is deprecated and will be removed in a future release. '
26+
'Please use `enable_gpu` instead.',
27+
FutureWarning,
28+
)
29+
enable_gpu = cuda
30+
31+
if enable_gpu:
32+
if sys.platform == 'darwin': # macOS
33+
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
34+
device = 'mps'
35+
else:
36+
device = 'cpu'
37+
else: # Linux/Windows
38+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
39+
else:
40+
device = 'cpu'
41+
42+
return torch.device(device)
43+
44+
1745
def _expand_context(data, context):
1846
return torch.cat(
1947
[
@@ -165,36 +193,16 @@ def __init__(
165193
gen_lr=1e-3,
166194
dis_lr=1e-3,
167195
enable_gpu=True,
168-
cuda=None,
169196
verbose=True,
197+
cuda=None,
170198
):
171-
if cuda is not None:
172-
warnings.warn(
173-
'`cuda` parameter is deprecated and will be removed in a future release. '
174-
'Please use `enable_gpu` instead.',
175-
FutureWarning,
176-
)
177-
enable_gpu = cuda
178-
179199
self._epochs = epochs
180200
self._gen_lr = gen_lr
181201
self._dis_lr = dis_lr
182202
self._latent_size = latent_size
183203
self._hidden_size = hidden_size
184-
self._enable_gpu = enable_gpu
185-
186-
if enable_gpu:
187-
if sys.platform == 'darwin': # macOS
188-
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
189-
device = 'mps'
190-
else:
191-
device = 'cpu'
192-
else: # Linux/Windows
193-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
194-
else:
195-
device = 'cpu'
196-
197-
self._device = torch.device(device)
204+
self._device = _set_device(enable_gpu, cuda)
205+
self._enable_gpu = cuda if cuda is not None else enable_gpu
198206
self._verbose = verbose
199207

200208
LOGGER.info('%s instance created', self)

tests/integration/test_basic_gan.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,44 @@
77
import pytest
88
import torch
99

10-
from deepecho.models.basic_gan import BasicGANModel
10+
from deepecho.models.basic_gan import BasicGANModel, _set_device
11+
12+
13+
def test__set_device():
14+
"""Test the ``_set_device`` method."""
15+
# Setup
16+
expected_error = re.escape(
17+
'Cannot set `cuda` and `enable_gpu` together. Please use only `enable_gpu`.'
18+
)
19+
expected_warning = re.escape(
20+
'`cuda` parameter is deprecated and will be removed in a future release. '
21+
'Please use `enable_gpu` instead.'
22+
)
23+
24+
# Run
25+
device_1 = _set_device(False, None)
26+
device_2 = _set_device(True, None)
27+
with pytest.warns(FutureWarning, match=expected_warning):
28+
device_3 = _set_device(True, False)
29+
30+
with pytest.raises(ValueError, match=expected_error):
31+
_set_device(False, True)
32+
33+
# Assert
34+
if (
35+
sys.platform == 'darwin'
36+
and getattr(torch.backends, 'mps', None)
37+
and torch.backends.mps.is_available()
38+
):
39+
expected_device_2 = torch.device('mps')
40+
elif torch.cuda.is_available():
41+
expected_device_2 = torch.device('cuda')
42+
else:
43+
expected_device_2 = torch.device('cpu')
44+
45+
assert device_1 == torch.device('cpu')
46+
assert device_2 == expected_device_2
47+
assert device_3 == torch.device('cpu')
1148

1249

1350
class TestBasicGANModel(unittest.TestCase):

0 commit comments

Comments
 (0)