Skip to content

Commit 39dc87f

Browse files
committed
cleaning
1 parent 63f2add commit 39dc87f

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

deepecho/models/_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
import torch
55

66

7-
def _validate_gpu_parameter(enable_gpu, cuda):
7+
def _validate_gpu_parameters(enable_gpu, cuda):
8+
"""Validate both the `enable_gpu` and `cuda` parameters.
9+
10+
The logic here is to:
11+
- Raise a warning if `cuda` is set because it's deprecated.
12+
- Raise an error if both parameters are set in a conflicting way.
13+
- Return the resolved `enable_gpu` value.
14+
"""
815
if cuda is not None:
916
warnings.warn(
1017
'`cuda` parameter is deprecated and will be removed in a future release. '
@@ -22,10 +29,8 @@ def _validate_gpu_parameter(enable_gpu, cuda):
2229
return enable_gpu
2330

2431

25-
def _set_device(enable_gpu, device=None):
26-
if device:
27-
return torch.device(device)
28-
32+
def _set_device(enable_gpu):
33+
"""Set the torch device based on the `enable_gpu` parameter and system capabilities."""
2934
if enable_gpu:
3035
if sys.platform == 'darwin': # macOS
3136
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
@@ -41,5 +46,6 @@ def _set_device(enable_gpu, device=None):
4146

4247

4348
def validate_and_set_device(enable_gpu, cuda):
44-
enable_gpu = _validate_gpu_parameter(enable_gpu, cuda)
49+
"""Validate the GPU parameters and set the torch device accordingly."""
50+
enable_gpu = _validate_gpu_parameters(enable_gpu, cuda)
4551
return _set_device(enable_gpu)

tests/unit/models/test__utils.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import pytest
88
import torch
99

10-
from deepecho.models._utils import _set_device, _validate_gpu_parameter, validate_and_set_device
10+
from deepecho.models._utils import _set_device, _validate_gpu_parameters, validate_and_set_device
1111

1212

13-
def test__validate_gpu_parameter():
14-
"""Test the ``_validate_gpu_parameter`` method."""
13+
def test__validate_gpu_parameterss():
14+
"""Test the ``_validate_gpu_parameters`` method."""
1515
# Setup
1616
expected_error = re.escape(
1717
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
@@ -23,13 +23,13 @@ def test__validate_gpu_parameter():
2323
)
2424

2525
# Run
26-
enable_gpu_1 = _validate_gpu_parameter(False, None)
27-
enable_gpu_2 = _validate_gpu_parameter(True, None)
26+
enable_gpu_1 = _validate_gpu_parameters(False, None)
27+
enable_gpu_2 = _validate_gpu_parameters(True, None)
2828
with pytest.warns(FutureWarning, match=expected_warning):
29-
enable_gpu_3 = _validate_gpu_parameter(True, False)
29+
enable_gpu_3 = _validate_gpu_parameters(True, False)
3030

3131
with pytest.raises(ValueError, match=expected_error):
32-
_validate_gpu_parameter(False, True)
32+
_validate_gpu_parameters(False, True)
3333

3434
# Assert
3535
assert enable_gpu_1 is False
@@ -42,8 +42,6 @@ def test__set_device():
4242
# Run
4343
device_1 = _set_device(False)
4444
device_2 = _set_device(True)
45-
device_3 = _set_device(True, 'cpu')
46-
device_4 = _set_device(enable_gpu=False, device='cpu')
4745

4846
# Assert
4947
if (
@@ -59,12 +57,10 @@ def test__set_device():
5957

6058
assert device_1 == torch.device('cpu')
6159
assert device_2 == expected_device_2
62-
assert device_3 == torch.device('cpu')
63-
assert device_4 == torch.device('cpu')
6460

6561

6662
@patch('deepecho.models._utils._set_device')
67-
@patch('deepecho.models._utils._validate_gpu_parameter')
63+
@patch('deepecho.models._utils._validate_gpu_parameters')
6864
def test_validate_and_set_device(mock_validate, mock_set_device):
6965
"""Test the ``validate_and_set_device`` method."""
7066
# Setup

0 commit comments

Comments
 (0)