Skip to content

Commit 63f2add

Browse files
committed
address comments
1 parent 19be578 commit 63f2add

File tree

4 files changed

+129
-70
lines changed

4 files changed

+129
-70
lines changed

deepecho/models/_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import sys
2+
import warnings
3+
4+
import torch
5+
6+
7+
def _validate_gpu_parameter(enable_gpu, cuda):
8+
if cuda is not None:
9+
warnings.warn(
10+
'`cuda` parameter is deprecated and will be removed in a future release. '
11+
'Please use `enable_gpu` instead.',
12+
FutureWarning,
13+
)
14+
if not enable_gpu:
15+
raise ValueError(
16+
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
17+
'Please use only `enable_gpu`.'
18+
)
19+
20+
enable_gpu = cuda
21+
22+
return enable_gpu
23+
24+
25+
def _set_device(enable_gpu, device=None):
26+
if device:
27+
return torch.device(device)
28+
29+
if enable_gpu:
30+
if sys.platform == 'darwin': # macOS
31+
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
32+
device = 'mps'
33+
else:
34+
device = 'cpu'
35+
else: # Linux/Windows
36+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
37+
else:
38+
device = 'cpu'
39+
40+
return torch.device(device)
41+
42+
43+
def validate_and_set_device(enable_gpu, cuda):
44+
enable_gpu = _validate_gpu_parameter(enable_gpu, cuda)
45+
return _set_device(enable_gpu)

deepecho/models/basic_gan.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,18 @@
11
"""BasicGAN Model."""
22

33
import logging
4-
import sys
5-
import warnings
64

75
import numpy as np
86
import pandas as pd
97
import torch
108
from tqdm import tqdm
119

10+
from deepecho.models._utils import validate_and_set_device
1211
from deepecho.models.base import DeepEcho
1312

1413
LOGGER = logging.getLogger(__name__)
1514

1615

17-
def _set_device(enable_gpu, cuda):
18-
if cuda is not None:
19-
if not enable_gpu:
20-
raise ValueError(
21-
'Cannot set `cuda` and `enable_gpu` together. Please use only `enable_gpu`.'
22-
)
23-
24-
warnings.warn(
25-
'`cuda` parameter is deprecated and will be removed in a future release. '
26-
'Please use `enable_gpu` instead.',
27-
FutureWarning,
28-
)
29-
enable_gpu = cuda
30-
31-
if enable_gpu:
32-
if sys.platform == 'darwin': # macOS
33-
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
34-
device = 'mps'
35-
else:
36-
device = 'cpu'
37-
else: # Linux/Windows
38-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
39-
else:
40-
device = 'cpu'
41-
42-
return torch.device(device)
43-
44-
4516
def _expand_context(data, context):
4617
return torch.cat(
4718
[
@@ -170,7 +141,7 @@ class BasicGANModel(DeepEcho):
170141
Whether to attempt to use GPU for computation.
171142
Defaults to ``True``.
172143
cuda (bool):
173-
** Deprecated ** Whether to attempt to use cuda for GPU computation.
144+
**Deprecated** Whether to attempt to use cuda for GPU computation.
174145
If this is False or CUDA is not available, CPU will be used.
175146
verbose (bool):
176147
Whether to print progress to console or not.
@@ -201,7 +172,7 @@ def __init__(
201172
self._dis_lr = dis_lr
202173
self._latent_size = latent_size
203174
self._hidden_size = hidden_size
204-
self._device = _set_device(enable_gpu, cuda)
175+
self._device = validate_and_set_device(enable_gpu, cuda)
205176
self._enable_gpu = cuda if cuda is not None else enable_gpu
206177
self._verbose = verbose
207178

tests/integration/test_basic_gan.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,44 +7,7 @@
77
import pytest
88
import torch
99

10-
from deepecho.models.basic_gan import BasicGANModel, _set_device
11-
12-
13-
def test__set_device():
14-
"""Test the ``_set_device`` method."""
15-
# Setup
16-
expected_error = re.escape(
17-
'Cannot set `cuda` and `enable_gpu` together. Please use only `enable_gpu`.'
18-
)
19-
expected_warning = re.escape(
20-
'`cuda` parameter is deprecated and will be removed in a future release. '
21-
'Please use `enable_gpu` instead.'
22-
)
23-
24-
# Run
25-
device_1 = _set_device(False, None)
26-
device_2 = _set_device(True, None)
27-
with pytest.warns(FutureWarning, match=expected_warning):
28-
device_3 = _set_device(True, False)
29-
30-
with pytest.raises(ValueError, match=expected_error):
31-
_set_device(False, True)
32-
33-
# Assert
34-
if (
35-
sys.platform == 'darwin'
36-
and getattr(torch.backends, 'mps', None)
37-
and torch.backends.mps.is_available()
38-
):
39-
expected_device_2 = torch.device('mps')
40-
elif torch.cuda.is_available():
41-
expected_device_2 = torch.device('cuda')
42-
else:
43-
expected_device_2 = torch.device('cpu')
44-
45-
assert device_1 == torch.device('cpu')
46-
assert device_2 == expected_device_2
47-
assert device_3 == torch.device('cpu')
10+
from deepecho.models.basic_gan import BasicGANModel
4811

4912

5013
class TestBasicGANModel(unittest.TestCase):

tests/unit/models/test__utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Test the ``_utils`` module."""
2+
3+
import re
4+
import sys
5+
from unittest.mock import patch
6+
7+
import pytest
8+
import torch
9+
10+
from deepecho.models._utils import _set_device, _validate_gpu_parameter, validate_and_set_device
11+
12+
13+
def test__validate_gpu_parameter():
14+
"""Test the ``_validate_gpu_parameter`` method."""
15+
# Setup
16+
expected_error = re.escape(
17+
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
18+
'Please use only `enable_gpu`.'
19+
)
20+
expected_warning = re.escape(
21+
'`cuda` parameter is deprecated and will be removed in a future release. '
22+
'Please use `enable_gpu` instead.'
23+
)
24+
25+
# Run
26+
enable_gpu_1 = _validate_gpu_parameter(False, None)
27+
enable_gpu_2 = _validate_gpu_parameter(True, None)
28+
with pytest.warns(FutureWarning, match=expected_warning):
29+
enable_gpu_3 = _validate_gpu_parameter(True, False)
30+
31+
with pytest.raises(ValueError, match=expected_error):
32+
_validate_gpu_parameter(False, True)
33+
34+
# Assert
35+
assert enable_gpu_1 is False
36+
assert enable_gpu_2 is True
37+
assert enable_gpu_3 is False
38+
39+
40+
def test__set_device():
41+
"""Test the ``_set_device`` method."""
42+
# Run
43+
device_1 = _set_device(False)
44+
device_2 = _set_device(True)
45+
device_3 = _set_device(True, 'cpu')
46+
device_4 = _set_device(enable_gpu=False, device='cpu')
47+
48+
# Assert
49+
if (
50+
sys.platform == 'darwin'
51+
and getattr(torch.backends, 'mps', None)
52+
and torch.backends.mps.is_available()
53+
):
54+
expected_device_2 = torch.device('mps')
55+
elif torch.cuda.is_available():
56+
expected_device_2 = torch.device('cuda')
57+
else:
58+
expected_device_2 = torch.device('cpu')
59+
60+
assert device_1 == torch.device('cpu')
61+
assert device_2 == expected_device_2
62+
assert device_3 == torch.device('cpu')
63+
assert device_4 == torch.device('cpu')
64+
65+
66+
@patch('deepecho.models._utils._set_device')
67+
@patch('deepecho.models._utils._validate_gpu_parameter')
68+
def test_validate_and_set_device(mock_validate, mock_set_device):
69+
"""Test the ``validate_and_set_device`` method."""
70+
# Setup
71+
mock_validate.return_value = True
72+
mock_set_device.return_value = torch.device('cuda')
73+
74+
# Run
75+
device = validate_and_set_device(enable_gpu=True, cuda=None)
76+
77+
# Assert
78+
mock_validate.assert_called_once_with(True, None)
79+
mock_set_device.assert_called_once_with(True)
80+
assert device == torch.device('cuda')

0 commit comments

Comments
 (0)