Skip to content

Commit 4538bce

Browse files
committed
_get_enable_gpu_value + fix backward compatibility
1 parent 231dd97 commit 4538bce

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

ctgan/synthesizers/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66

7-
def _validate_gpu_parameters(enable_gpu, cuda):
7+
def _get_enable_gpu_value(enable_gpu, cuda):
88
"""Validate both the `enable_gpu` and `cuda` parameters.
99
1010
The logic here is to:
@@ -53,5 +53,5 @@ def _set_device(enable_gpu, device=None):
5353

5454

5555
def validate_and_set_device(enable_gpu, cuda):
56-
enable_gpu = _validate_gpu_parameters(enable_gpu, cuda)
56+
enable_gpu = _get_enable_gpu_value(enable_gpu, cuda)
5757
return _set_device(enable_gpu)

ctgan/synthesizers/ctgan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ def sample(self, n, condition_column=None, condition_value=None):
541541

542542
def set_device(self, device):
543543
"""Set the `device` to be used ('GPU' or 'CPU)."""
544-
self._device = _set_device(self._enable_gpu, device)
544+
enable_gpu = getattr(self, '_enable_gpu', True)
545+
self._device = _set_device(enable_gpu, device)
545546
if self._generator is not None:
546547
self._generator.to(self._device)

ctgan/synthesizers/tvae.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,5 +237,6 @@ def sample(self, samples):
237237

238238
def set_device(self, device):
239239
"""Set the `device` to be used ('GPU' or 'CPU')."""
240-
self._device = _set_device(self._enable_gpu, device)
240+
enable_gpu = getattr(self, '_enable_gpu', True)
241+
self._device = _set_device(enable_gpu, device)
241242
self.decoder.to(self._device)

tests/unit/synthesizer/test__utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import pytest
66
import torch
77

8-
from ctgan.synthesizers._utils import _set_device, _validate_gpu_parameters, validate_and_set_device
8+
from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device
99

1010

1111
def test__validate_gpu_parameter():
12-
"""Test the ``_validate_gpu_parameters`` method."""
12+
"""Test the ``_get_enable_gpu_value`` method."""
1313
# Setup
1414
expected_error = re.escape(
1515
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
@@ -21,13 +21,13 @@ def test__validate_gpu_parameter():
2121
)
2222

2323
# Run
24-
enable_gpu_1 = _validate_gpu_parameters(False, None)
25-
enable_gpu_2 = _validate_gpu_parameters(True, None)
24+
enable_gpu_1 = _get_enable_gpu_value(False, None)
25+
enable_gpu_2 = _get_enable_gpu_value(True, None)
2626
with pytest.warns(FutureWarning, match=expected_warning):
27-
enable_gpu_3 = _validate_gpu_parameters(True, False)
27+
enable_gpu_3 = _get_enable_gpu_value(True, False)
2828

2929
with pytest.raises(ValueError, match=expected_error):
30-
_validate_gpu_parameters(False, True)
30+
_get_enable_gpu_value(False, True)
3131

3232
# Assert
3333
assert enable_gpu_1 is False
@@ -62,7 +62,7 @@ def test__set_device():
6262

6363

6464
@patch('ctgan.synthesizers._utils._set_device')
65-
@patch('ctgan.synthesizers._utils._validate_gpu_parameters')
65+
@patch('ctgan.synthesizers._utils._get_enable_gpu_value')
6666
def test_validate_and_set_device(mock_validate, mock_set_device):
6767
"""Test the ``validate_and_set_device`` method."""
6868
# Setup

0 commit comments

Comments
 (0)