55import pytest
66import torch
77
8- from ctgan .synthesizers ._utils import _get_enable_gpu_value , _set_device , validate_and_set_device
8+ from ctgan .synthesizers ._utils import _set_device , get_enable_gpu_value , validate_and_set_device
99
1010
1111def test__validate_gpu_parameter ():
12- """Test the ``_get_enable_gpu_value `` method."""
12+ """Test the ``get_enable_gpu_value `` method."""
1313 # Setup
1414 expected_error = re .escape (
1515 'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
@@ -21,13 +21,13 @@ def test__validate_gpu_parameter():
2121 )
2222
2323 # Run
24- enable_gpu_1 = _get_enable_gpu_value (False , None )
25- enable_gpu_2 = _get_enable_gpu_value (True , None )
24+ enable_gpu_1 = get_enable_gpu_value (False , None )
25+ enable_gpu_2 = get_enable_gpu_value (True , None )
2626 with pytest .warns (FutureWarning , match = expected_warning ):
27- enable_gpu_3 = _get_enable_gpu_value (True , False )
27+ enable_gpu_3 = get_enable_gpu_value (True , False )
2828
2929 with pytest .raises (ValueError , match = expected_error ):
30- _get_enable_gpu_value (False , True )
30+ get_enable_gpu_value (False , True )
3131
3232 # Assert
3333 assert enable_gpu_1 is False
@@ -62,7 +62,7 @@ def test__set_device():
6262
6363
6464@patch ('ctgan.synthesizers._utils._set_device' )
65- @patch ('ctgan.synthesizers._utils._get_enable_gpu_value ' )
65+ @patch ('ctgan.synthesizers._utils.get_enable_gpu_value ' )
6666def test_validate_and_set_device (mock_validate , mock_set_device ):
6767 """Test the ``validate_and_set_device`` method."""
6868 # Setup
0 commit comments