
Commit c0b9495

_get_enable_gpu_value

1 parent 8647933

File tree

5 files changed (+19 additions, -21 deletions)

deepecho/models/_utils.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 import torch
 
 
-def _validate_gpu_parameters(enable_gpu, cuda):
+def _get_enable_gpu_value(enable_gpu, cuda):
     """Validate both the `enable_gpu` and `cuda` parameters.
 
     The logic here is to:
@@ -48,5 +48,5 @@ def _set_device(enable_gpu):
 
 def validate_and_set_device(enable_gpu, cuda):
     """Validate the GPU parameters and set the torch device accordingly."""
-    enable_gpu = _validate_gpu_parameters(enable_gpu, cuda)
+    enable_gpu = _get_enable_gpu_value(enable_gpu, cuda)
     return _set_device(enable_gpu)
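
Note: the hunk above only shows the rename; the body of the resolution logic is not part of this diff. A minimal sketch of what `_get_enable_gpu_value` is expected to do, inferred from the unit tests further down in this commit; the exact structure and messages are assumptions, not the actual implementation:

import warnings


def _get_enable_gpu_value(enable_gpu, cuda):
    """Resolve `enable_gpu` and the deprecated `cuda` into one boolean (sketch only)."""
    if cuda is None:
        # Only the new parameter was used; take it as-is.
        return enable_gpu

    if enable_gpu is False and cuda is True:
        # Contradictory values cannot be resolved; the unit test expects a ValueError here.
        raise ValueError(
            'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters.'
        )

    # `cuda` was passed explicitly: honor it, but warn that `enable_gpu` is the replacement.
    warnings.warn('`cuda` is deprecated, use `enable_gpu` instead.', FutureWarning)
    return cuda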

deepecho/models/basic_gan.py

Lines changed: 2 additions & 3 deletions

@@ -140,11 +140,11 @@ class BasicGANModel(DeepEcho):
         enable_gpu (bool):
             Whether to attempt to use GPU for computation.
             Defaults to ``True``.
+        verbose (bool):
+            Whether to print progress to console or not.
         cuda (bool):
             **Deprecated** Whether to attempt to use cuda for GPU computation.
             If this is False or CUDA is not available, CPU will be used.
-        verbose (bool):
-            Whether to print progress to console or not.
     """
 
     _max_sequence_length = None
@@ -173,7 +173,6 @@ def __init__(
         self._latent_size = latent_size
         self._hidden_size = hidden_size
         self._device = validate_and_set_device(enable_gpu, cuda)
-        self._enable_gpu = cuda if cuda is not None else enable_gpu
         self._verbose = verbose
 
         LOGGER.info('%s instance created', self)
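
With the `self._enable_gpu` attribute removed, the resolved GPU choice is only reflected in the torch device stored by the constructor. A hedged illustration of that behavior; the class path and attribute names come from this diff, the snippet itself is an example and not part of the commit:

import torch

from deepecho.models.basic_gan import BasicGANModel

# Disabling the GPU resolves straight to the CPU device; there is no longer a
# separate `_enable_gpu` attribute to inspect.
model = BasicGANModel(epochs=10, enable_gpu=False)
assert model._device == torch.device('cpu')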

deepecho/models/par.py

Lines changed: 6 additions & 4 deletions

@@ -91,19 +91,21 @@ class PARModel(DeepEcho):
             The number of times to sample (before choosing and
             returning the sample which maximizes the likelihood).
             Defaults to 1.
-        cuda (bool):
-            Whether to attempt to use cuda for GPU computation.
-            If this is False or CUDA is not available, CPU will be used.
+        enable_gpu (bool):
+            Whether to attempt to use GPU for computation.
             Defaults to ``True``.
         verbose (bool):
            Whether to print progress to console or not.
+        cuda (bool):
+            **Deprecated** Whether to attempt to use cuda for GPU computation.
+            If this is False or CUDA is not available, CPU will be used.
+            Defaults to ``True``.
     """
 
     def __init__(self, epochs=128, sample_size=1, enable_gpu=True, verbose=True, cuda=None):
         self.epochs = epochs
         self.sample_size = sample_size
         self.device = validate_and_set_device(enable_gpu=enable_gpu, cuda=cuda)
-        self._enable_gpu = cuda if cuda is not None else enable_gpu
         self.verbose = verbose
         self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])
 
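
For callers, the signature above means the GPU toggle now goes through `enable_gpu`, while `cuda` keeps working but is deprecated. A short usage sketch; the exact warning text is not shown in this hunk, only its presence as a FutureWarning:

from deepecho.models.par import PARModel

# Preferred going forward: the new keyword controls GPU usage directly.
model = PARModel(epochs=128, enable_gpu=False)

# Deprecated path: `cuda` is still accepted, but it emits a FutureWarning,
# as the integration test below checks for BasicGANModel.
legacy_model = PARModel(epochs=128, cuda=False)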

tests/integration/test_basic_gan.py

Lines changed: 1 addition & 4 deletions

@@ -23,9 +23,7 @@ def test_deprecation_warning(self):
 
         # Run and Assert
         with pytest.warns(FutureWarning, match=expected_message):
-            model = BasicGANModel(epochs=10, cuda=False)
-
-        assert model._enable_gpu is False
+            BasicGANModel(epochs=10, cuda=False)
 
     def test__init___enable_gpu(self):
         """Test when `enable_gpu` parameter in the constructor."""
@@ -45,7 +43,6 @@ def test__init___enable_gpu(self):
         expected_device = torch.device('cpu')
 
         assert model._device == expected_device
-        assert model._enable_gpu is True
 
     def test_basic(self):
         """Basic test for the ``BasicGANModel``."""

tests/unit/models/test__utils.py

Lines changed: 8 additions & 8 deletions

@@ -7,11 +7,11 @@
 import pytest
 import torch
 
-from deepecho.models._utils import _set_device, _validate_gpu_parameters, validate_and_set_device
+from deepecho.models._utils import _get_enable_gpu_value, _set_device, validate_and_set_device
 
 
-def test__validate_gpu_parameterss():
-    """Test the ``_validate_gpu_parameters`` method."""
+def test__get_enable_gpu_values():
+    """Test the ``_get_enable_gpu_value`` method."""
     # Setup
     expected_error = re.escape(
         'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
@@ -23,13 +23,13 @@ def test__validate_gpu_parameterss():
     )
 
     # Run
-    enable_gpu_1 = _validate_gpu_parameters(False, None)
-    enable_gpu_2 = _validate_gpu_parameters(True, None)
+    enable_gpu_1 = _get_enable_gpu_value(False, None)
+    enable_gpu_2 = _get_enable_gpu_value(True, None)
     with pytest.warns(FutureWarning, match=expected_warning):
-        enable_gpu_3 = _validate_gpu_parameters(True, False)
+        enable_gpu_3 = _get_enable_gpu_value(True, False)
 
     with pytest.raises(ValueError, match=expected_error):
-        _validate_gpu_parameters(False, True)
+        _get_enable_gpu_value(False, True)
 
     # Assert
     assert enable_gpu_1 is False
@@ -60,7 +60,7 @@ def test__set_device():
 
 
 @patch('deepecho.models._utils._set_device')
-@patch('deepecho.models._utils._validate_gpu_parameters')
+@patch('deepecho.models._utils._get_enable_gpu_value')
 def test_validate_and_set_device(mock_validate, mock_set_device):
     """Test the ``validate_and_set_device`` method."""
     # Setup

0 commit comments