From 1938413a6d52c5ad94071db31ddf32ec6c4e7892 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 23 Sep 2025 14:34:15 -0400 Subject: [PATCH 1/6] def 462 --- ctgan/synthesizers/_utils.py | 53 ++++++++++++++++++++++++++++++++++++ ctgan/synthesizers/base.py | 4 ++- ctgan/synthesizers/ctgan.py | 23 +++++++--------- ctgan/synthesizers/tvae.py | 19 +++++-------- 4 files changed, 73 insertions(+), 26 deletions(-) create mode 100644 ctgan/synthesizers/_utils.py diff --git a/ctgan/synthesizers/_utils.py b/ctgan/synthesizers/_utils.py new file mode 100644 index 00000000..daf360a8 --- /dev/null +++ b/ctgan/synthesizers/_utils.py @@ -0,0 +1,53 @@ +import sys +import warnings + +import torch + + +def _validate_gpu_parameters(enable_gpu, cuda): + """Validate both the `enable_gpu` and `cuda` parameters. + + The logic here is to: + - Raise a warning if `cuda` is set because it's deprecated. + - Raise an error if both parameters are set in a conflicting way. + - Return the resolved `enable_gpu` value. + """ + if cuda is not None: + warnings.warn( + '`cuda` parameter is deprecated and will be removed in a future release. ' + 'Please use `enable_gpu` instead.', + FutureWarning, + ) + if not enable_gpu: + raise ValueError( + 'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. ' + 'Please use only `enable_gpu`.' + ) + + enable_gpu = cuda + + return enable_gpu + + +def _set_device(enable_gpu, device=None): + """Set the torch device based on the `enable_gpu` parameter and system capabilities.""" + if device: + return torch.device(device) + + if enable_gpu: + if sys.platform == 'darwin': # macOS + if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available(): + device = 'mps' + else: + device = 'cpu' + else: # Linux/Windows + device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + device = 'cpu' + + return torch.device(device) + + +def validate_and_set_device(enable_gpu, cuda): + enable_gpu = _validate_gpu_parameters(enable_gpu, cuda) + return _set_device(enable_gpu) diff --git a/ctgan/synthesizers/base.py b/ctgan/synthesizers/base.py index 14b4f9c9..a8460b26 100644 --- a/ctgan/synthesizers/base.py +++ b/ctgan/synthesizers/base.py @@ -5,6 +5,8 @@ import numpy as np import torch +from ctgan.synthesizers._utils import _set_device + @contextlib.contextmanager def set_random_states(random_state, set_model_random_state): @@ -118,7 +120,7 @@ def save(self, path): @classmethod def load(cls, path): """Load the model stored in the passed `path`.""" - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + device = _set_device(enable_gpu=True) model = torch.load(path, weights_only=False) model.set_device(device) return model diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 29606a34..3057df9c 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -12,6 +12,7 @@ from ctgan.data_sampler import DataSampler from ctgan.data_transformer import DataTransformer from ctgan.errors import InvalidDataError +from ctgan.synthesizers._utils import _set_device, validate_and_set_device from ctgan.synthesizers.base import BaseSynthesizer, random_state @@ -138,8 +139,11 @@ class CTGAN(BaseSynthesizer): pac (int): Number of samples to group together when applying the discriminator. Defaults to 10. + enable_gpu (bool): + Whether to attempt to use GPU for computation. + Defaults to ``True``. cuda (bool): - Whether to attempt to use cuda for GPU computation. + ** Deprecated ** Whether to attempt to use cuda for GPU computation. If this is False or CUDA is not available, CPU will be used. Defaults to ``True``. """ @@ -159,7 +163,8 @@ def __init__( verbose=False, epochs=300, pac=10, - cuda=True, + enable_gpu=True, + cuda=None, ): assert batch_size % 2 == 0 @@ -178,16 +183,8 @@ def __init__( self._verbose = verbose self._epochs = epochs self.pac = pac - - if not cuda or not torch.cuda.is_available(): - device = 'cpu' - elif isinstance(cuda, str): - device = cuda - else: - device = 'cuda' - - self._device = torch.device(device) - + self._device = validate_and_set_device(enable_gpu, cuda) + self._enable_gpu = cuda if cuda is not None else enable_gpu self._transformer = None self._data_sampler = None self._generator = None @@ -544,6 +541,6 @@ def sample(self, n, condition_column=None, condition_value=None): def set_device(self, device): """Set the `device` to be used ('GPU' or 'CPU).""" - self._device = device + self._device = _set_device(self._enable_gpu, device) if self._generator is not None: self._generator.to(self._device) diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index ecefbb5f..b3ee3721 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -10,6 +10,7 @@ from tqdm import tqdm from ctgan.data_transformer import DataTransformer +from ctgan.synthesizers._utils import _set_device, validate_and_set_device from ctgan.synthesizers.base import BaseSynthesizer, random_state @@ -114,8 +115,9 @@ def __init__( batch_size=500, epochs=300, loss_factor=2, - cuda=True, + enable_gpu=True, verbose=False, + cuda=None, ): self.embedding_dim = embedding_dim self.compress_dims = compress_dims @@ -127,15 +129,8 @@ def __init__( self.epochs = epochs self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) self.verbose = verbose - - if not cuda or not torch.cuda.is_available(): - device = 'cpu' - elif isinstance(cuda, str): - device = cuda - else: - device = 'cuda' - - self._device = torch.device(device) + self._device = validate_and_set_device(enable_gpu, cuda) + self._enable_gpu = cuda if cuda is not None else enable_gpu @random_state def fit(self, train_data, discrete_columns=()): @@ -241,6 +236,6 @@ def sample(self, samples): return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy()) def set_device(self, device): - """Set the `device` to be used ('GPU' or 'CPU).""" - self._device = device + """Set the `device` to be used ('GPU' or 'CPU').""" + self._device = _set_device(self._enable_gpu, device) self.decoder.to(self._device) From d9abea1f1a16b49310d2d6a1a0284da62c23b61b Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 23 Sep 2025 14:34:30 -0400 Subject: [PATCH 2/6] tests --- tests/integration/synthesizer/test_ctgan.py | 2 +- tests/unit/synthesizer/test__utils.py | 78 +++++++++++++++++++++ tests/unit/synthesizer/test_ctgan.py | 26 ++++++- tests/unit/synthesizer/test_tvae.py | 21 ++++++ 4 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 tests/unit/synthesizer/test__utils.py diff --git a/tests/integration/synthesizer/test_ctgan.py b/tests/integration/synthesizer/test_ctgan.py index 62881c18..eec9bbaa 100644 --- a/tests/integration/synthesizer/test_ctgan.py +++ b/tests/integration/synthesizer/test_ctgan.py @@ -236,7 +236,7 @@ def test_fixed_random_seed(): }) discrete_columns = ['discrete'] - ctgan = CTGAN(epochs=1, cuda=False) + ctgan = CTGAN(epochs=1, enable_gpu=False) # Run ctgan.fit(data, discrete_columns) diff --git a/tests/unit/synthesizer/test__utils.py b/tests/unit/synthesizer/test__utils.py new file mode 100644 index 00000000..f0d8f673 --- /dev/null +++ b/tests/unit/synthesizer/test__utils.py @@ -0,0 +1,78 @@ +import re +import sys +from unittest.mock import patch + +import pytest +import torch + +from ctgan.synthesizers._utils import _set_device, _validate_gpu_parameters, validate_and_set_device + + +def test__validate_gpu_parameter(): + """Test the ``_validate_gpu_parameters`` method.""" + # Setup + expected_error = re.escape( + 'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. ' + 'Please use only `enable_gpu`.' + ) + expected_warning = re.escape( + '`cuda` parameter is deprecated and will be removed in a future release. ' + 'Please use `enable_gpu` instead.' + ) + + # Run + enable_gpu_1 = _validate_gpu_parameters(False, None) + enable_gpu_2 = _validate_gpu_parameters(True, None) + with pytest.warns(FutureWarning, match=expected_warning): + enable_gpu_3 = _validate_gpu_parameters(True, False) + + with pytest.raises(ValueError, match=expected_error): + _validate_gpu_parameters(False, True) + + # Assert + assert enable_gpu_1 is False + assert enable_gpu_2 is True + assert enable_gpu_3 is False + + +def test__set_device(): + """Test the ``_set_device`` method.""" + # Run + device_1 = _set_device(False) + device_2 = _set_device(True) + device_3 = _set_device(True, 'cpu') + device_4 = _set_device(enable_gpu=False, device='cpu') + + # Assert + if ( + sys.platform == 'darwin' + and getattr(torch.backends, 'mps', None) + and torch.backends.mps.is_available() + ): + expected_device_2 = torch.device('mps') + elif torch.cuda.is_available(): + expected_device_2 = torch.device('cuda') + else: + expected_device_2 = torch.device('cpu') + + assert device_1 == torch.device('cpu') + assert device_2 == expected_device_2 + assert device_3 == torch.device('cpu') + assert device_4 == torch.device('cpu') + + +@patch('ctgan.synthesizers._utils._set_device') +@patch('ctgan.synthesizers._utils._validate_gpu_parameters') +def test_validate_and_set_device(mock_validate, mock_set_device): + """Test the ``validate_and_set_device`` method.""" + # Setup + mock_validate.return_value = True + mock_set_device.return_value = torch.device('cuda') + + # Run + device = validate_and_set_device(enable_gpu=True, cuda=None) + + # Assert + mock_validate.assert_called_once_with(True, None) + mock_set_device.assert_called_once_with(True) + assert device == torch.device('cuda') diff --git a/tests/unit/synthesizer/test_ctgan.py b/tests/unit/synthesizer/test_ctgan.py index 8b774358..f131d983 100644 --- a/tests/unit/synthesizer/test_ctgan.py +++ b/tests/unit/synthesizer/test_ctgan.py @@ -1,7 +1,7 @@ """CTGAN unit testing module.""" from unittest import TestCase -from unittest.mock import Mock +from unittest.mock import Mock, patch import numpy as np import pandas as pd @@ -175,6 +175,30 @@ def _assert_is_between(data, lower, upper): class TestCTGAN(TestCase): + @patch('ctgan.synthesizers.ctgan.validate_and_set_device') + def test___init__(self, mock_validate_and_set_device): + """Test the `__init__` method.""" + # Setup + mock_validate_and_set_device.return_value = 'cpu' + + # Run + synth = CTGAN() + + # Assert + assert synth._embedding_dim == 128 + assert synth._generator_dim == (256, 256) + assert synth._discriminator_dim == (256, 256) + assert synth._batch_size == 500 + assert synth._epochs == 300 + assert synth.pac == 10 + assert synth.loss_values is None + assert synth._generator is None + assert synth._data_sampler is None + assert synth._verbose is False + assert synth._enable_gpu is True + assert synth._device == 'cpu' + mock_validate_and_set_device.assert_called_once_with(True, None) + def test__apply_activate_(self): """Test `_apply_activate` for tables with both continuous and categoricals. diff --git a/tests/unit/synthesizer/test_tvae.py b/tests/unit/synthesizer/test_tvae.py index 115569eb..259b2618 100644 --- a/tests/unit/synthesizer/test_tvae.py +++ b/tests/unit/synthesizer/test_tvae.py @@ -8,6 +8,27 @@ class TestTVAE: + @patch('ctgan.synthesizers.tvae.validate_and_set_device') + def test___init__(self, mock_validate_and_set_device): + """Test the `__init__` method.""" + # Setup + mock_validate_and_set_device.return_value = 'cpu' + + # Run + synth = TVAE() + + # Assert + assert synth.embedding_dim == 128 + assert synth.compress_dims == (128, 128) + assert synth.decompress_dims == (128, 128) + assert synth.batch_size == 500 + assert synth.epochs == 300 + assert synth.loss_values.equals(pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])) + assert synth.verbose is False + assert synth._enable_gpu is True + assert synth._device == 'cpu' + mock_validate_and_set_device.assert_called_once_with(True, None) + @patch('ctgan.synthesizers.tvae._loss_function') @patch('ctgan.synthesizers.tvae.tqdm') def test_fit_verbose(self, tqdm_mock, loss_func_mock): From ea7456f788f493864b78c373758858dd8b486498 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 23 Sep 2025 15:54:24 -0400 Subject: [PATCH 3/6] fix tests --- ctgan/synthesizers/_utils.py | 10 +++++++--- pyproject.toml | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ctgan/synthesizers/_utils.py b/ctgan/synthesizers/_utils.py index daf360a8..56fd7b83 100644 --- a/ctgan/synthesizers/_utils.py +++ b/ctgan/synthesizers/_utils.py @@ -1,4 +1,4 @@ -import sys +import platform import warnings import torch @@ -35,8 +35,12 @@ def _set_device(enable_gpu, device=None): return torch.device(device) if enable_gpu: - if sys.platform == 'darwin': # macOS - if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available(): + if platform.system() == 'Darwin': # macOS + if ( + platform.machine() == 'arm64' + and getattr(torch.backends, 'mps', None) + and torch.backends.mps.is_available() + ): device = 'mps' else: device = 'cpu' diff --git a/pyproject.toml b/pyproject.toml index 577626a8..2e06389d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "pandas>=2.2.3;python_version>='3.13'", "torch>=1.13.0;python_version<'3.11'", "torch>=2.0.0;python_version>='3.11' and python_version<'3.12'", - "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'", + "torch>=2.3.0;python_version>='3.12' and python_version<'3.13'", "torch>=2.6.0;python_version>='3.13'", 'tqdm>=4.29,<5', 'rdt>=1.14.0', From f12f1e086629aac7eed20b0132c5c8c32afe3faa Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 23 Sep 2025 16:11:01 -0400 Subject: [PATCH 4/6] fix tests 2 --- tests/unit/synthesizer/test__utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/synthesizer/test__utils.py b/tests/unit/synthesizer/test__utils.py index f0d8f673..a19bf536 100644 --- a/tests/unit/synthesizer/test__utils.py +++ b/tests/unit/synthesizer/test__utils.py @@ -1,5 +1,5 @@ +import platform import re -import sys from unittest.mock import patch import pytest @@ -45,7 +45,7 @@ def test__set_device(): # Assert if ( - sys.platform == 'darwin' + platform.machine() == 'arm64' and getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available() ): From 231dd9716e73eeca2b9020da711530f16d759636 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Wed, 24 Sep 2025 11:05:54 -0400 Subject: [PATCH 5/6] docstring --- ctgan/synthesizers/base.py | 2 +- ctgan/synthesizers/ctgan.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ctgan/synthesizers/base.py b/ctgan/synthesizers/base.py index a8460b26..10bc2fe0 100644 --- a/ctgan/synthesizers/base.py +++ b/ctgan/synthesizers/base.py @@ -107,7 +107,7 @@ def __setstate__(self, state): state['random_states'] = (current_numpy_state, current_torch_state) self.__dict__ = state - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + device = _set_device(enable_gpu=True) self.set_device(device) def save(self, path): diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 3057df9c..69b99945 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -143,7 +143,7 @@ class CTGAN(BaseSynthesizer): Whether to attempt to use GPU for computation. Defaults to ``True``. cuda (bool): - ** Deprecated ** Whether to attempt to use cuda for GPU computation. + **Deprecated** Whether to attempt to use cuda for GPU computation. If this is False or CUDA is not available, CPU will be used. Defaults to ``True``. """ From 4538bce3924e90b83e6dc7dbc6ae852064141239 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 25 Sep 2025 15:37:23 -0400 Subject: [PATCH 6/6] _get_enable_gpu_value + fix backward compatibility --- ctgan/synthesizers/_utils.py | 4 ++-- ctgan/synthesizers/ctgan.py | 3 ++- ctgan/synthesizers/tvae.py | 3 ++- tests/unit/synthesizer/test__utils.py | 14 +++++++------- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ctgan/synthesizers/_utils.py b/ctgan/synthesizers/_utils.py index 56fd7b83..20cfedfc 100644 --- a/ctgan/synthesizers/_utils.py +++ b/ctgan/synthesizers/_utils.py @@ -4,7 +4,7 @@ import torch -def _validate_gpu_parameters(enable_gpu, cuda): +def _get_enable_gpu_value(enable_gpu, cuda): """Validate both the `enable_gpu` and `cuda` parameters. The logic here is to: @@ -53,5 +53,5 @@ def _set_device(enable_gpu, device=None): def validate_and_set_device(enable_gpu, cuda): - enable_gpu = _validate_gpu_parameters(enable_gpu, cuda) + enable_gpu = _get_enable_gpu_value(enable_gpu, cuda) return _set_device(enable_gpu) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 69b99945..d9398856 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -541,6 +541,7 @@ def sample(self, n, condition_column=None, condition_value=None): def set_device(self, device): """Set the `device` to be used ('GPU' or 'CPU).""" - self._device = _set_device(self._enable_gpu, device) + enable_gpu = getattr(self, '_enable_gpu', True) + self._device = _set_device(enable_gpu, device) if self._generator is not None: self._generator.to(self._device) diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index b3ee3721..30baea3d 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -237,5 +237,6 @@ def sample(self, samples): def set_device(self, device): """Set the `device` to be used ('GPU' or 'CPU').""" - self._device = _set_device(self._enable_gpu, device) + enable_gpu = getattr(self, '_enable_gpu', True) + self._device = _set_device(enable_gpu, device) self.decoder.to(self._device) diff --git a/tests/unit/synthesizer/test__utils.py b/tests/unit/synthesizer/test__utils.py index a19bf536..b634fe95 100644 --- a/tests/unit/synthesizer/test__utils.py +++ b/tests/unit/synthesizer/test__utils.py @@ -5,11 +5,11 @@ import pytest import torch -from ctgan.synthesizers._utils import _set_device, _validate_gpu_parameters, validate_and_set_device +from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device def test__validate_gpu_parameter(): - """Test the ``_validate_gpu_parameters`` method.""" + """Test the ``_get_enable_gpu_value`` method.""" # Setup expected_error = re.escape( 'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. ' @@ -21,13 +21,13 @@ def test__validate_gpu_parameter(): ) # Run - enable_gpu_1 = _validate_gpu_parameters(False, None) - enable_gpu_2 = _validate_gpu_parameters(True, None) + enable_gpu_1 = _get_enable_gpu_value(False, None) + enable_gpu_2 = _get_enable_gpu_value(True, None) with pytest.warns(FutureWarning, match=expected_warning): - enable_gpu_3 = _validate_gpu_parameters(True, False) + enable_gpu_3 = _get_enable_gpu_value(True, False) with pytest.raises(ValueError, match=expected_error): - _validate_gpu_parameters(False, True) + _get_enable_gpu_value(False, True) # Assert assert enable_gpu_1 is False @@ -62,7 +62,7 @@ def test__set_device(): @patch('ctgan.synthesizers._utils._set_device') -@patch('ctgan.synthesizers._utils._validate_gpu_parameters') +@patch('ctgan.synthesizers._utils._get_enable_gpu_value') def test_validate_and_set_device(mock_validate, mock_set_device): """Test the ``validate_and_set_device`` method.""" # Setup