diff --git a/ctgan/synthesizers/_utils.py b/ctgan/synthesizers/_utils.py new file mode 100644 index 00000000..20cfedfc --- /dev/null +++ b/ctgan/synthesizers/_utils.py @@ -0,0 +1,57 @@ +import platform +import warnings + +import torch + + +def _get_enable_gpu_value(enable_gpu, cuda): + """Validate both the `enable_gpu` and `cuda` parameters. + + The logic here is to: + - Raise a warning if `cuda` is set because it's deprecated. + - Raise an error if both parameters are set in a conflicting way. + - Return the resolved `enable_gpu` value. + """ + if cuda is not None: + warnings.warn( + '`cuda` parameter is deprecated and will be removed in a future release. ' + 'Please use `enable_gpu` instead.', + FutureWarning, + ) + if not enable_gpu: + raise ValueError( + 'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. ' + 'Please use only `enable_gpu`.' + ) + + enable_gpu = cuda + + return enable_gpu + + +def _set_device(enable_gpu, device=None): + """Set the torch device based on the `enable_gpu` parameter and system capabilities.""" + if device: + return torch.device(device) + + if enable_gpu: + if platform.system() == 'Darwin': # macOS + if ( + platform.machine() == 'arm64' + and getattr(torch.backends, 'mps', None) + and torch.backends.mps.is_available() + ): + device = 'mps' + else: + device = 'cpu' + else: # Linux/Windows + device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + device = 'cpu' + + return torch.device(device) + + +def validate_and_set_device(enable_gpu, cuda): + enable_gpu = _get_enable_gpu_value(enable_gpu, cuda) + return _set_device(enable_gpu) diff --git a/ctgan/synthesizers/base.py b/ctgan/synthesizers/base.py index 14b4f9c9..10bc2fe0 100644 --- a/ctgan/synthesizers/base.py +++ b/ctgan/synthesizers/base.py @@ -5,6 +5,8 @@ import numpy as np import torch +from ctgan.synthesizers._utils import _set_device + @contextlib.contextmanager def set_random_states(random_state, set_model_random_state): @@ -105,7 +107,7 @@ def __setstate__(self, state): state['random_states'] = (current_numpy_state, current_torch_state) self.__dict__ = state - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + device = _set_device(enable_gpu=True) self.set_device(device) def save(self, path): @@ -118,7 +120,7 @@ def save(self, path): @classmethod def load(cls, path): """Load the model stored in the passed `path`.""" - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + device = _set_device(enable_gpu=True) model = torch.load(path, weights_only=False) model.set_device(device) return model diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 29606a34..d9398856 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -12,6 +12,7 @@ from ctgan.data_sampler import DataSampler from ctgan.data_transformer import DataTransformer from ctgan.errors import InvalidDataError +from ctgan.synthesizers._utils import _set_device, validate_and_set_device from ctgan.synthesizers.base import BaseSynthesizer, random_state @@ -138,8 +139,11 @@ class CTGAN(BaseSynthesizer): pac (int): Number of samples to group together when applying the discriminator. Defaults to 10. + enable_gpu (bool): + Whether to attempt to use GPU for computation. + Defaults to ``True``. cuda (bool): - Whether to attempt to use cuda for GPU computation. + **Deprecated** Whether to attempt to use cuda for GPU computation. If this is False or CUDA is not available, CPU will be used. Defaults to ``True``. """ @@ -159,7 +163,8 @@ def __init__( verbose=False, epochs=300, pac=10, - cuda=True, + enable_gpu=True, + cuda=None, ): assert batch_size % 2 == 0 @@ -178,16 +183,8 @@ def __init__( self._verbose = verbose self._epochs = epochs self.pac = pac - - if not cuda or not torch.cuda.is_available(): - device = 'cpu' - elif isinstance(cuda, str): - device = cuda - else: - device = 'cuda' - - self._device = torch.device(device) - + self._device = validate_and_set_device(enable_gpu, cuda) + self._enable_gpu = cuda if cuda is not None else enable_gpu self._transformer = None self._data_sampler = None self._generator = None @@ -544,6 +541,7 @@ def sample(self, n, condition_column=None, condition_value=None): def set_device(self, device): """Set the `device` to be used ('GPU' or 'CPU).""" - self._device = device + enable_gpu = getattr(self, '_enable_gpu', True) + self._device = _set_device(enable_gpu, device) if self._generator is not None: self._generator.to(self._device) diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index ecefbb5f..30baea3d 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -10,6 +10,7 @@ from tqdm import tqdm from ctgan.data_transformer import DataTransformer +from ctgan.synthesizers._utils import _set_device, validate_and_set_device from ctgan.synthesizers.base import BaseSynthesizer, random_state @@ -114,8 +115,9 @@ def __init__( batch_size=500, epochs=300, loss_factor=2, - cuda=True, + enable_gpu=True, verbose=False, + cuda=None, ): self.embedding_dim = embedding_dim self.compress_dims = compress_dims @@ -127,15 +129,8 @@ def __init__( self.epochs = epochs self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) self.verbose = verbose - - if not cuda or not torch.cuda.is_available(): - device = 'cpu' - elif isinstance(cuda, str): - device = cuda - else: - device = 'cuda' - - self._device = torch.device(device) + self._device = validate_and_set_device(enable_gpu, cuda) + self._enable_gpu = cuda if cuda is not None else enable_gpu @random_state def fit(self, train_data, discrete_columns=()): @@ -241,6 +236,7 @@ def sample(self, samples): return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy()) def set_device(self, device): - """Set the `device` to be used ('GPU' or 'CPU).""" - self._device = device + """Set the `device` to be used ('GPU' or 'CPU').""" + enable_gpu = getattr(self, '_enable_gpu', True) + self._device = _set_device(enable_gpu, device) self.decoder.to(self._device) diff --git a/pyproject.toml b/pyproject.toml index 577626a8..2e06389d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "pandas>=2.2.3;python_version>='3.13'", "torch>=1.13.0;python_version<'3.11'", "torch>=2.0.0;python_version>='3.11' and python_version<'3.12'", - "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'", + "torch>=2.3.0;python_version>='3.12' and python_version<'3.13'", "torch>=2.6.0;python_version>='3.13'", 'tqdm>=4.29,<5', 'rdt>=1.14.0', diff --git a/tests/integration/synthesizer/test_ctgan.py b/tests/integration/synthesizer/test_ctgan.py index 62881c18..eec9bbaa 100644 --- a/tests/integration/synthesizer/test_ctgan.py +++ b/tests/integration/synthesizer/test_ctgan.py @@ -236,7 +236,7 @@ def test_fixed_random_seed(): }) discrete_columns = ['discrete'] - ctgan = CTGAN(epochs=1, cuda=False) + ctgan = CTGAN(epochs=1, enable_gpu=False) # Run ctgan.fit(data, discrete_columns) diff --git a/tests/unit/synthesizer/test__utils.py b/tests/unit/synthesizer/test__utils.py new file mode 100644 index 00000000..b634fe95 --- /dev/null +++ b/tests/unit/synthesizer/test__utils.py @@ -0,0 +1,78 @@ +import platform +import re +from unittest.mock import patch + +import pytest +import torch + +from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device + + +def test__validate_gpu_parameter(): + """Test the ``_get_enable_gpu_value`` method.""" + # Setup + expected_error = re.escape( + 'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. ' + 'Please use only `enable_gpu`.' + ) + expected_warning = re.escape( + '`cuda` parameter is deprecated and will be removed in a future release. ' + 'Please use `enable_gpu` instead.' + ) + + # Run + enable_gpu_1 = _get_enable_gpu_value(False, None) + enable_gpu_2 = _get_enable_gpu_value(True, None) + with pytest.warns(FutureWarning, match=expected_warning): + enable_gpu_3 = _get_enable_gpu_value(True, False) + + with pytest.raises(ValueError, match=expected_error): + _get_enable_gpu_value(False, True) + + # Assert + assert enable_gpu_1 is False + assert enable_gpu_2 is True + assert enable_gpu_3 is False + + +def test__set_device(): + """Test the ``_set_device`` method.""" + # Run + device_1 = _set_device(False) + device_2 = _set_device(True) + device_3 = _set_device(True, 'cpu') + device_4 = _set_device(enable_gpu=False, device='cpu') + + # Assert + if ( + platform.machine() == 'arm64' + and getattr(torch.backends, 'mps', None) + and torch.backends.mps.is_available() + ): + expected_device_2 = torch.device('mps') + elif torch.cuda.is_available(): + expected_device_2 = torch.device('cuda') + else: + expected_device_2 = torch.device('cpu') + + assert device_1 == torch.device('cpu') + assert device_2 == expected_device_2 + assert device_3 == torch.device('cpu') + assert device_4 == torch.device('cpu') + + +@patch('ctgan.synthesizers._utils._set_device') +@patch('ctgan.synthesizers._utils._get_enable_gpu_value') +def test_validate_and_set_device(mock_validate, mock_set_device): + """Test the ``validate_and_set_device`` method.""" + # Setup + mock_validate.return_value = True + mock_set_device.return_value = torch.device('cuda') + + # Run + device = validate_and_set_device(enable_gpu=True, cuda=None) + + # Assert + mock_validate.assert_called_once_with(True, None) + mock_set_device.assert_called_once_with(True) + assert device == torch.device('cuda') diff --git a/tests/unit/synthesizer/test_ctgan.py b/tests/unit/synthesizer/test_ctgan.py index 8b774358..f131d983 100644 --- a/tests/unit/synthesizer/test_ctgan.py +++ b/tests/unit/synthesizer/test_ctgan.py @@ -1,7 +1,7 @@ """CTGAN unit testing module.""" from unittest import TestCase -from unittest.mock import Mock +from unittest.mock import Mock, patch import numpy as np import pandas as pd @@ -175,6 +175,30 @@ def _assert_is_between(data, lower, upper): class TestCTGAN(TestCase): + @patch('ctgan.synthesizers.ctgan.validate_and_set_device') + def test___init__(self, mock_validate_and_set_device): + """Test the `__init__` method.""" + # Setup + mock_validate_and_set_device.return_value = 'cpu' + + # Run + synth = CTGAN() + + # Assert + assert synth._embedding_dim == 128 + assert synth._generator_dim == (256, 256) + assert synth._discriminator_dim == (256, 256) + assert synth._batch_size == 500 + assert synth._epochs == 300 + assert synth.pac == 10 + assert synth.loss_values is None + assert synth._generator is None + assert synth._data_sampler is None + assert synth._verbose is False + assert synth._enable_gpu is True + assert synth._device == 'cpu' + mock_validate_and_set_device.assert_called_once_with(True, None) + def test__apply_activate_(self): """Test `_apply_activate` for tables with both continuous and categoricals. diff --git a/tests/unit/synthesizer/test_tvae.py b/tests/unit/synthesizer/test_tvae.py index 115569eb..259b2618 100644 --- a/tests/unit/synthesizer/test_tvae.py +++ b/tests/unit/synthesizer/test_tvae.py @@ -8,6 +8,27 @@ class TestTVAE: + @patch('ctgan.synthesizers.tvae.validate_and_set_device') + def test___init__(self, mock_validate_and_set_device): + """Test the `__init__` method.""" + # Setup + mock_validate_and_set_device.return_value = 'cpu' + + # Run + synth = TVAE() + + # Assert + assert synth.embedding_dim == 128 + assert synth.compress_dims == (128, 128) + assert synth.decompress_dims == (128, 128) + assert synth.batch_size == 500 + assert synth.epochs == 300 + assert synth.loss_values.equals(pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])) + assert synth.verbose is False + assert synth._enable_gpu is True + assert synth._device == 'cpu' + mock_validate_and_set_device.assert_called_once_with(True, None) + @patch('ctgan.synthesizers.tvae._loss_function') @patch('ctgan.synthesizers.tvae.tqdm') def test_fit_verbose(self, tqdm_mock, loss_func_mock):