Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ctgan/synthesizers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch


def _get_enable_gpu_value(enable_gpu, cuda):
def get_enable_gpu_value(enable_gpu, cuda):
"""Validate both the `enable_gpu` and `cuda` parameters.

The logic here is to:
Expand Down Expand Up @@ -53,5 +53,5 @@ def _set_device(enable_gpu, device=None):


def validate_and_set_device(enable_gpu, cuda):
enable_gpu = _get_enable_gpu_value(enable_gpu, cuda)
enable_gpu = get_enable_gpu_value(enable_gpu, cuda)
return _set_device(enable_gpu)
14 changes: 7 additions & 7 deletions tests/unit/synthesizer/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import pytest
import torch

from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device
from ctgan.synthesizers._utils import _set_device, get_enable_gpu_value, validate_and_set_device


def test__validate_gpu_parameter():
"""Test the ``_get_enable_gpu_value`` method."""
"""Test the ``get_enable_gpu_value`` method."""
# Setup
expected_error = re.escape(
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
Expand All @@ -21,13 +21,13 @@ def test__validate_gpu_parameter():
)

# Run
enable_gpu_1 = _get_enable_gpu_value(False, None)
enable_gpu_2 = _get_enable_gpu_value(True, None)
enable_gpu_1 = get_enable_gpu_value(False, None)
enable_gpu_2 = get_enable_gpu_value(True, None)
with pytest.warns(FutureWarning, match=expected_warning):
enable_gpu_3 = _get_enable_gpu_value(True, False)
enable_gpu_3 = get_enable_gpu_value(True, False)

with pytest.raises(ValueError, match=expected_error):
_get_enable_gpu_value(False, True)
get_enable_gpu_value(False, True)

# Assert
assert enable_gpu_1 is False
Expand Down Expand Up @@ -62,7 +62,7 @@ def test__set_device():


@patch('ctgan.synthesizers._utils._set_device')
@patch('ctgan.synthesizers._utils._get_enable_gpu_value')
@patch('ctgan.synthesizers._utils.get_enable_gpu_value')
def test_validate_and_set_device(mock_validate, mock_set_device):
"""Test the ``validate_and_set_device`` method."""
# Setup
Expand Down
Loading