57 changes: 57 additions & 0 deletions ctgan/synthesizers/_utils.py
@@ -0,0 +1,57 @@
import platform
import warnings

import torch


def _get_enable_gpu_value(enable_gpu, cuda):
"""Validate both the `enable_gpu` and `cuda` parameters.

The logic here is to:
- Raise a warning if `cuda` is set because it's deprecated.
- Raise an error if `enable_gpu` is explicitly set to False while `cuda` is also passed.
- Return the resolved `enable_gpu` value.
"""
if cuda is not None:
warnings.warn(
'`cuda` parameter is deprecated and will be removed in a future release. '
'Please use `enable_gpu` instead.',
FutureWarning,
)
if not enable_gpu:
raise ValueError(
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
'Please use only `enable_gpu`.'
)

enable_gpu = cuda

return enable_gpu


def _set_device(enable_gpu, device=None):
"""Set the torch device based on the `enable_gpu` parameter and system capabilities."""
if device:
return torch.device(device)

if enable_gpu:
if platform.system() == 'Darwin': # macOS
if (
platform.machine() == 'arm64'
and getattr(torch.backends, 'mps', None)
and torch.backends.mps.is_available()
):
device = 'mps'
else:
device = 'cpu'
else: # Linux/Windows
device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
device = 'cpu'

return torch.device(device)


def validate_and_set_device(enable_gpu, cuda):
"""Resolve the GPU flags and return the torch device to use."""
enable_gpu = _get_enable_gpu_value(enable_gpu, cuda)
return _set_device(enable_gpu)
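
Usage sketch (illustrative only, not part of the diff; it assumes nothing beyond the `validate_and_set_device` helper defined above):

import warnings

from ctgan.synthesizers._utils import validate_and_set_device

# No legacy flag: the device follows `enable_gpu` and the host capabilities.
device = validate_and_set_device(enable_gpu=True, cuda=None)
print(device)  # cuda, mps, or cpu depending on the machine

# The legacy `cuda` flag still resolves, but emits a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    device = validate_and_set_device(enable_gpu=True, cuda=False)
print(device)               # cpu, because the explicit cuda=False wins
print(caught[-1].category)  # <class 'FutureWarning'>

# Conflicting values are rejected.
try:
    validate_and_set_device(enable_gpu=False, cuda=True)
except ValueError as error:
    print(error)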
6 changes: 4 additions & 2 deletions ctgan/synthesizers/base.py
@@ -5,6 +5,8 @@
import numpy as np
import torch

from ctgan.synthesizers._utils import _set_device


@contextlib.contextmanager
def set_random_states(random_state, set_model_random_state):
@@ -105,7 +107,7 @@ def __setstate__(self, state):
state['random_states'] = (current_numpy_state, current_torch_state)

self.__dict__ = state
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = _set_device(enable_gpu=True)
self.set_device(device)

def save(self, path):
@@ -118,7 +120,7 @@ def save(self, path):
@classmethod
def load(cls, path):
"""Load the model stored in the passed `path`."""
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = _set_device(enable_gpu=True)
model = torch.load(path, weights_only=False)
model.set_device(device)
return model
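
The practical effect of this change is that loading a saved synthesizer no longer pins it to 'cuda:0'; the device is re-resolved from the host. A minimal sketch (illustrative, assuming the CTGAN class and the save/load API shown in this diff):

from ctgan import CTGAN

model = CTGAN(epochs=1, enable_gpu=False)
# ... model.fit(data, discrete_columns) would go here ...
model.save('ctgan_model.pt')

# `load` re-resolves the device through _set_device(enable_gpu=True), so the
# restored model lands on cuda, mps, or cpu depending on the machine rather
# than the previously hard-coded 'cuda:0'.
restored = CTGAN.load('ctgan_model.pt')
print(restored._device)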
24 changes: 11 additions & 13 deletions ctgan/synthesizers/ctgan.py
@@ -12,6 +12,7 @@
from ctgan.data_sampler import DataSampler
from ctgan.data_transformer import DataTransformer
from ctgan.errors import InvalidDataError
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
from ctgan.synthesizers.base import BaseSynthesizer, random_state


@@ -138,8 +139,11 @@ class CTGAN(BaseSynthesizer):
pac (int):
Number of samples to group together when applying the discriminator.
Defaults to 10.
enable_gpu (bool):
Whether to attempt to use GPU for computation.
Defaults to ``True``.
cuda (bool):
Whether to attempt to use cuda for GPU computation.
**Deprecated** Whether to attempt to use cuda for GPU computation.
If this is False or CUDA is not available, CPU will be used.
Defaults to ``None``.
"""
@@ -159,7 +163,8 @@ def __init__(
verbose=False,
epochs=300,
pac=10,
cuda=True,
enable_gpu=True,
cuda=None,
):
assert batch_size % 2 == 0

@@ -178,16 +183,8 @@
self._verbose = verbose
self._epochs = epochs
self.pac = pac

if not cuda or not torch.cuda.is_available():
device = 'cpu'
elif isinstance(cuda, str):
device = cuda
else:
device = 'cuda'

self._device = torch.device(device)

self._device = validate_and_set_device(enable_gpu, cuda)
self._enable_gpu = cuda if cuda is not None else enable_gpu
self._transformer = None
self._data_sampler = None
self._generator = None
@@ -544,6 +541,7 @@ def sample(self, n, condition_column=None, condition_value=None):

def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU)."""
self._device = device
enable_gpu = getattr(self, '_enable_gpu', True)
self._device = _set_device(enable_gpu, device)
if self._generator is not None:
self._generator.to(self._device)
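
Below is a hedged sketch (not part of the PR) of how the updated constructor behaves; it assumes only the signature and attributes shown above:

import warnings

from ctgan import CTGAN

# Preferred going forward: control device selection with `enable_gpu`.
synth = CTGAN(epochs=10, enable_gpu=False)
print(synth._device)      # cpu
print(synth._enable_gpu)  # False

# Legacy call sites keep working for now, but emit a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    legacy = CTGAN(epochs=10, cuda=False)
print(legacy._device)       # cpu
print(caught[-1].category)  # <class 'FutureWarning'>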
20 changes: 8 additions & 12 deletions ctgan/synthesizers/tvae.py
@@ -10,6 +10,7 @@
from tqdm import tqdm

from ctgan.data_transformer import DataTransformer
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
from ctgan.synthesizers.base import BaseSynthesizer, random_state


@@ -114,8 +115,9 @@ def __init__(
batch_size=500,
epochs=300,
loss_factor=2,
cuda=True,
enable_gpu=True,
verbose=False,
cuda=None,
):
self.embedding_dim = embedding_dim
self.compress_dims = compress_dims
@@ -127,15 +129,8 @@
self.epochs = epochs
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
self.verbose = verbose

if not cuda or not torch.cuda.is_available():
device = 'cpu'
elif isinstance(cuda, str):
device = cuda
else:
device = 'cuda'

self._device = torch.device(device)
self._device = validate_and_set_device(enable_gpu, cuda)
self._enable_gpu = cuda if cuda is not None else enable_gpu

@random_state
def fit(self, train_data, discrete_columns=()):
@@ -241,6 +236,7 @@ def sample(self, samples):
return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy())

def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU)."""
self._device = device
"""Set the `device` to be used ('GPU' or 'CPU')."""
enable_gpu = getattr(self, '_enable_gpu', True)
self._device = _set_device(enable_gpu, device)
self.decoder.to(self._device)
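
TVAE follows the same pattern; a short sketch (illustrative, assuming the signature shown above):

from ctgan import TVAE

tvae = TVAE(epochs=5, enable_gpu=False)
print(tvae._device)      # cpu
print(tvae._enable_gpu)  # False

# set_device now routes through _set_device: an explicit device argument wins,
# otherwise the stored _enable_gpu flag decides between GPU and CPU. Note that
# TVAE.set_device moves the decoder, so it is meant to be called after fit().
# tvae.set_device('cpu')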
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
"pandas>=2.2.3;python_version>='3.13'",
"torch>=1.13.0;python_version<'3.11'",
"torch>=2.0.0;python_version>='3.11' and python_version<'3.12'",
"torch>=2.2.0;python_version>='3.12' and python_version<'3.13'",
"torch>=2.3.0;python_version>='3.12' and python_version<'3.13'",
"torch>=2.6.0;python_version>='3.13'",
'tqdm>=4.29,<5',
'rdt>=1.14.0',
2 changes: 1 addition & 1 deletion tests/integration/synthesizer/test_ctgan.py
@@ -236,7 +236,7 @@ def test_fixed_random_seed():
})
discrete_columns = ['discrete']

ctgan = CTGAN(epochs=1, cuda=False)
ctgan = CTGAN(epochs=1, enable_gpu=False)

# Run
ctgan.fit(data, discrete_columns)
78 changes: 78 additions & 0 deletions tests/unit/synthesizer/test__utils.py
@@ -0,0 +1,78 @@
import platform
import re
from unittest.mock import patch

import pytest
import torch

from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device


def test__validate_gpu_parameter():
"""Test the ``_get_enable_gpu_value`` method."""
# Setup
expected_error = re.escape(
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
'Please use only `enable_gpu`.'
)
expected_warning = re.escape(
'`cuda` parameter is deprecated and will be removed in a future release. '
'Please use `enable_gpu` instead.'
)

# Run
enable_gpu_1 = _get_enable_gpu_value(False, None)
enable_gpu_2 = _get_enable_gpu_value(True, None)
with pytest.warns(FutureWarning, match=expected_warning):
enable_gpu_3 = _get_enable_gpu_value(True, False)

with pytest.raises(ValueError, match=expected_error):
_get_enable_gpu_value(False, True)

# Assert
assert enable_gpu_1 is False
assert enable_gpu_2 is True
assert enable_gpu_3 is False


def test__set_device():
"""Test the ``_set_device`` method."""
# Run
device_1 = _set_device(False)
device_2 = _set_device(True)
device_3 = _set_device(True, 'cpu')
device_4 = _set_device(enable_gpu=False, device='cpu')

# Assert
if (
platform.machine() == 'arm64'
and getattr(torch.backends, 'mps', None)
and torch.backends.mps.is_available()
):
expected_device_2 = torch.device('mps')
elif torch.cuda.is_available():
expected_device_2 = torch.device('cuda')
else:
expected_device_2 = torch.device('cpu')

assert device_1 == torch.device('cpu')
assert device_2 == expected_device_2
assert device_3 == torch.device('cpu')
assert device_4 == torch.device('cpu')


@patch('ctgan.synthesizers._utils._set_device')
@patch('ctgan.synthesizers._utils._get_enable_gpu_value')
def test_validate_and_set_device(mock_validate, mock_set_device):
"""Test the ``validate_and_set_device`` method."""
# Setup
mock_validate.return_value = True
mock_set_device.return_value = torch.device('cuda')

# Run
device = validate_and_set_device(enable_gpu=True, cuda=None)

# Assert
mock_validate.assert_called_once_with(True, None)
mock_set_device.assert_called_once_with(True)
assert device == torch.device('cuda')
26 changes: 25 additions & 1 deletion tests/unit/synthesizer/test_ctgan.py
@@ -1,7 +1,7 @@
"""CTGAN unit testing module."""

from unittest import TestCase
from unittest.mock import Mock
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
@@ -175,6 +175,30 @@ def _assert_is_between(data, lower, upper):


class TestCTGAN(TestCase):
@patch('ctgan.synthesizers.ctgan.validate_and_set_device')
def test___init__(self, mock_validate_and_set_device):
"""Test the `__init__` method."""
# Setup
mock_validate_and_set_device.return_value = 'cpu'

# Run
synth = CTGAN()

# Assert
assert synth._embedding_dim == 128
assert synth._generator_dim == (256, 256)
assert synth._discriminator_dim == (256, 256)
assert synth._batch_size == 500
assert synth._epochs == 300
assert synth.pac == 10
assert synth.loss_values is None
assert synth._generator is None
assert synth._data_sampler is None
assert synth._verbose is False
assert synth._enable_gpu is True
assert synth._device == 'cpu'
mock_validate_and_set_device.assert_called_once_with(True, None)

def test__apply_activate_(self):
"""Test `_apply_activate` for tables with both continuous and categoricals.

21 changes: 21 additions & 0 deletions tests/unit/synthesizer/test_tvae.py
@@ -8,6 +8,27 @@


class TestTVAE:
@patch('ctgan.synthesizers.tvae.validate_and_set_device')
def test___init__(self, mock_validate_and_set_device):
"""Test the `__init__` method."""
# Setup
mock_validate_and_set_device.return_value = 'cpu'

# Run
synth = TVAE()

# Assert
assert synth.embedding_dim == 128
assert synth.compress_dims == (128, 128)
assert synth.decompress_dims == (128, 128)
assert synth.batch_size == 500
assert synth.epochs == 300
assert synth.loss_values.equals(pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']))
assert synth.verbose is False
assert synth._enable_gpu is True
assert synth._device == 'cpu'
mock_validate_and_set_device.assert_called_once_with(True, None)

@patch('ctgan.synthesizers.tvae._loss_function')
@patch('ctgan.synthesizers.tvae.tqdm')
def test_fit_verbose(self, tqdm_mock, loss_func_mock):