Skip to content

Commit badd970

Browse files
authored
Make the _get_enable_gpu_value public (#467)
1 parent 60ddc47 commit badd970

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

ctgan/synthesizers/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66

7-
def _get_enable_gpu_value(enable_gpu, cuda):
7+
def get_enable_gpu_value(enable_gpu, cuda):
88
"""Validate both the `enable_gpu` and `cuda` parameters.
99
1010
The logic here is to:
@@ -53,5 +53,5 @@ def _set_device(enable_gpu, device=None):
5353

5454

5555
def validate_and_set_device(enable_gpu, cuda):
56-
enable_gpu = _get_enable_gpu_value(enable_gpu, cuda)
56+
enable_gpu = get_enable_gpu_value(enable_gpu, cuda)
5757
return _set_device(enable_gpu)

tests/unit/synthesizer/test__utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import pytest
66
import torch
77

8-
from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device
8+
from ctgan.synthesizers._utils import _set_device, get_enable_gpu_value, validate_and_set_device
99

1010

1111
def test__validate_gpu_parameter():
12-
"""Test the ``_get_enable_gpu_value`` method."""
12+
"""Test the ``get_enable_gpu_value`` method."""
1313
# Setup
1414
expected_error = re.escape(
1515
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
@@ -21,13 +21,13 @@ def test__validate_gpu_parameter():
2121
)
2222

2323
# Run
24-
enable_gpu_1 = _get_enable_gpu_value(False, None)
25-
enable_gpu_2 = _get_enable_gpu_value(True, None)
24+
enable_gpu_1 = get_enable_gpu_value(False, None)
25+
enable_gpu_2 = get_enable_gpu_value(True, None)
2626
with pytest.warns(FutureWarning, match=expected_warning):
27-
enable_gpu_3 = _get_enable_gpu_value(True, False)
27+
enable_gpu_3 = get_enable_gpu_value(True, False)
2828

2929
with pytest.raises(ValueError, match=expected_error):
30-
_get_enable_gpu_value(False, True)
30+
get_enable_gpu_value(False, True)
3131

3232
# Assert
3333
assert enable_gpu_1 is False
@@ -62,7 +62,7 @@ def test__set_device():
6262

6363

6464
@patch('ctgan.synthesizers._utils._set_device')
65-
@patch('ctgan.synthesizers._utils._get_enable_gpu_value')
65+
@patch('ctgan.synthesizers._utils.get_enable_gpu_value')
6666
def test_validate_and_set_device(mock_validate, mock_set_device):
6767
"""Test the ``validate_and_set_device`` method."""
6868
# Setup

0 commit comments

Comments
 (0)