Skip to content

Commit 2bd8841

Browse files
committed
When using PyTorch, enable GPU usage for MacOS (#2680)
1 parent 526cd12 commit 2bd8841

File tree

9 files changed

+165
-34
lines changed

9 files changed

+165
-34
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ dependencies = [
3535
"pandas>=2.2.3;python_version>='3.13'",
3636
'tqdm>=4.29',
3737
'copulas>=0.12.1',
38-
'ctgan>=0.11.0',
38+
'ctgan @ git+https://github.com/sdv-dev/CTGAN.git@main',
3939
'deepecho>=0.7.0',
4040
'rdt>=1.18.2',
4141
'sdmetrics>=0.21.0',

sdv/single_table/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,15 +309,23 @@ def update_transformers(self, column_name_to_transformer):
309309
msg = 'For this change to take effect, please refit the synthesizer using `fit`.'
310310
warnings.warn(msg, RefitWarning)
311311

312+
def _resolve_gpu_parameters(self, parameters):
    """Keep only one of the ``cuda`` / ``enable_gpu`` entries in ``parameters``.

    The dictionary is modified in place and the same object is returned.

    Args:
        parameters (dict):
            Mapping of ``__init__`` parameter names to their values.

    Returns:
        dict:
            The ``parameters`` object that was passed in. If ``cuda`` holds a
            value and ``enable_gpu`` is missing or ``None``, the ``enable_gpu``
            key is removed and ``cuda`` is kept. In every other case the
            ``cuda`` key is removed when present.
    """
    if parameters.get('cuda') is not None and parameters.get('enable_gpu') is None:
        parameters.pop('enable_gpu', None)  # Ensure backward-compatibility
    elif 'cuda' in parameters:  # Removed because deprecated
        del parameters['cuda']

    return parameters
319+
312320
def get_parameters(self):
313321
"""Return the parameters used to instantiate the synthesizer."""
314322
parameters = inspect.signature(self.__init__).parameters
315323
instantiated_parameters = {}
316324
for parameter_name in parameters:
317-
if parameter_name != 'metadata':
325+
if parameter_name not in ['metadata']:
318326
instantiated_parameters[parameter_name] = self.__dict__.get(parameter_name)
319327

320-
return instantiated_parameters
328+
return self._resolve_gpu_parameters(instantiated_parameters)
321329

322330
def get_metadata(self, version='original'):
323331
"""Get the metadata, either original or modified after applying constraints.

sdv/single_table/copulagan.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ class CopulaGANSynthesizer(CTGANSynthesizer):
8989
Whether to print fit progress on stdout. Defaults to ``False``.
9090
epochs (int):
9191
Number of training epochs. Defaults to 300.
92+
enable_gpu (bool):
93+
Whether to attempt to use GPU for computation.
94+
Defaults to ``True``.
9295
cuda (bool or str):
96+
**Deprecated**
9397
If ``True``, use CUDA. If a ``str``, use the indicated device.
9498
If ``False``, do not use cuda at all.
9599
numerical_distributions (dict):
@@ -139,9 +143,10 @@ def __init__(
139143
verbose=False,
140144
epochs=300,
141145
pac=10,
142-
cuda=True,
146+
enable_gpu=True,
143147
numerical_distributions=None,
144148
default_distribution=None,
149+
cuda=None,
145150
):
146151
super().__init__(
147152
metadata,
@@ -161,6 +166,7 @@ def __init__(
161166
verbose=verbose,
162167
epochs=epochs,
163168
pac=pac,
169+
enable_gpu=enable_gpu,
164170
cuda=cuda,
165171
)
166172

sdv/single_table/ctgan.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
try:
1616
from ctgan import CTGAN, TVAE
17+
from ctgan.synthesizers._utils import get_enable_gpu_value
1718

1819
import_error = None
1920
except ModuleNotFoundError as e:
@@ -154,7 +155,11 @@ class CTGANSynthesizer(LossValuesMixin, MissingModuleMixin, BaseSingleTableSynth
154155
pac (int):
155156
Number of samples to group together when applying the discriminator.
156157
Defaults to 10.
158+
enable_gpu (bool):
159+
Whether to attempt to use GPU for computation.
160+
Defaults to ``True``.
157161
cuda (bool or str):
162+
**Deprecated**
158163
If ``True``, use CUDA. If a ``str``, use the indicated device.
159164
If ``False``, do not use cuda at all.
160165
"""
@@ -180,7 +185,8 @@ def __init__(
180185
verbose=False,
181186
epochs=300,
182187
pac=10,
183-
cuda=True,
188+
enable_gpu=True,
189+
cuda=None,
184190
):
185191
if CTGAN is None:
186192
self.raise_module_not_found_error(import_error)
@@ -204,8 +210,7 @@ def __init__(
204210
self.verbose = verbose
205211
self.epochs = epochs
206212
self.pac = pac
207-
self.cuda = cuda
208-
213+
self.enable_gpu = get_enable_gpu_value(enable_gpu, cuda)
209214
self._model_kwargs = {
210215
'embedding_dim': embedding_dim,
211216
'generator_dim': generator_dim,
@@ -220,7 +225,7 @@ def __init__(
220225
'verbose': verbose,
221226
'epochs': epochs,
222227
'pac': pac,
223-
'cuda': cuda,
228+
'enable_gpu': self.enable_gpu,
224229
}
225230

226231
def _estimate_num_columns(self, data):
@@ -353,7 +358,11 @@ class TVAESynthesizer(LossValuesMixin, MissingModuleMixin, BaseSingleTableSynthe
353358
Number of training epochs. Defaults to 300.
354359
loss_factor (int):
355360
Multiplier for the reconstruction error. Defaults to 2.
361+
enable_gpu (bool):
362+
Whether to attempt to use GPU for computation.
363+
Defaults to ``True``.
356364
cuda (bool or str):
365+
**Deprecated**
357366
If ``True``, use CUDA. If a ``str``, use the indicated device.
358367
If ``False``, do not use cuda at all.
359368
"""
@@ -373,7 +382,8 @@ def __init__(
373382
verbose=False,
374383
epochs=300,
375384
loss_factor=2,
376-
cuda=True,
385+
enable_gpu=True,
386+
cuda=None,
377387
):
378388
if TVAE is None:
379389
self.raise_module_not_found_error(import_error)
@@ -390,8 +400,7 @@ def __init__(
390400
self.verbose = verbose
391401
self.epochs = epochs
392402
self.loss_factor = loss_factor
393-
self.cuda = cuda
394-
403+
self.enable_gpu = get_enable_gpu_value(enable_gpu, cuda)
395404
self._model_kwargs = {
396405
'embedding_dim': embedding_dim,
397406
'compress_dims': compress_dims,
@@ -401,7 +410,7 @@ def __init__(
401410
'verbose': verbose,
402411
'epochs': epochs,
403412
'loss_factor': loss_factor,
404-
'cuda': cuda,
413+
'enable_gpu': self.enable_gpu,
405414
}
406415

407416
def _fit(self, processed_data):

tests/integration/single_table/test_base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@
3434
})
3535

3636
SYNTHESIZERS = [
37-
pytest.param(CTGANSynthesizer(METADATA, epochs=1, cuda=False), id='CTGANSynthesizer'),
38-
pytest.param(TVAESynthesizer(METADATA, epochs=1, cuda=False), id='TVAESynthesizer'),
37+
pytest.param(CTGANSynthesizer(METADATA, epochs=1, enable_gpu=False), id='CTGANSynthesizer'),
38+
pytest.param(TVAESynthesizer(METADATA, epochs=1, enable_gpu=False), id='TVAESynthesizer'),
3939
pytest.param(GaussianCopulaSynthesizer(METADATA), id='GaussianCopulaSynthesizer'),
40-
pytest.param(CopulaGANSynthesizer(METADATA, epochs=1, cuda=False), id='CopulaGANSynthesizer'),
40+
pytest.param(
41+
CopulaGANSynthesizer(METADATA, epochs=1, enable_gpu=False), id='CopulaGANSynthesizer'
42+
),
4143
]
4244

4345

@@ -270,7 +272,7 @@ def test_sampling_reset_sampling(synthesizer):
270272
})
271273

272274
if isinstance(synthesizer, (CTGANSynthesizer, TVAESynthesizer)):
273-
synthesizer = synthesizer.__class__(metadata, cuda=False)
275+
synthesizer = synthesizer.__class__(metadata, enable_gpu=False)
274276
else:
275277
synthesizer = synthesizer.__class__(metadata)
276278

tests/integration/single_table/test_ctgan.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
import platform
12
import re
23
from unittest.mock import patch
34

45
import numpy as np
56
import pandas as pd
67
import pytest
8+
import torch
79
from rdt.transformers import FloatFormatter, LabelEncoder
810

911
from sdv.cag import FixedCombinations
1012
from sdv.datasets.demo import download_demo
1113
from sdv.errors import InvalidDataTypeError
1214
from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot
1315
from sdv.metadata.metadata import Metadata
14-
from sdv.single_table import CTGANSynthesizer, TVAESynthesizer
16+
from sdv.single_table import CopulaGANSynthesizer, CTGANSynthesizer, TVAESynthesizer
1517

1618

1719
def test__estimate_num_columns():
@@ -331,3 +333,61 @@ def test_tvae___init___without_torch(mock_import_error):
331333
# Run and Assert
332334
with pytest.raises(ModuleNotFoundError, match=msg):
333335
TVAESynthesizer(metadata)
336+
337+
338+
@pytest.mark.parametrize(
    'synthesizer_class', [CTGANSynthesizer, TVAESynthesizer, CopulaGANSynthesizer]
)
def test_enable_gpu_parameter(synthesizer_class):
    """Check that ``enable_gpu`` (and the deprecated ``cuda``) reach the underlying model."""
    # Setup
    data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
    expected_columns = data.columns.tolist()
    deprecation_pattern = re.escape(
        '`cuda` parameter is deprecated and will be removed in a future release. '
        'Please use `enable_gpu` instead.'
    )
    conflict_pattern = re.escape(
        'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
        'Please use only `enable_gpu`.'
    )

    # Device the model is expected to pick when GPU usage is enabled.
    mps_backend = getattr(torch.backends, 'mps', None)
    if platform.machine() == 'arm64' and mps_backend and mps_backend.is_available():
        expected_device = torch.device('mps')
    elif torch.cuda.is_available():
        expected_device = torch.device('cuda')
    else:
        expected_device = torch.device('cpu')

    # Run
    default_synthesizer = synthesizer_class(metadata)
    cpu_synthesizer = synthesizer_class(metadata, enable_gpu=False)
    with pytest.warns(FutureWarning, match=deprecation_pattern):
        legacy_synthesizer = synthesizer_class(metadata, cuda=True)

    with pytest.raises(ValueError, match=conflict_pattern):
        synthesizer_class(metadata, enable_gpu=False, cuda=True)

    synthesizers = (default_synthesizer, cpu_synthesizer, legacy_synthesizer)
    for synthesizer in synthesizers:
        synthesizer.fit(data)

    samples = [synthesizer.sample(10) for synthesizer in synthesizers]

    # Assert
    expected_states = (
        (True, expected_device),
        (False, torch.device('cpu')),
        (True, expected_device),
    )
    for synthesizer, (gpu_enabled, device) in zip(synthesizers, expected_states):
        assert synthesizer._model._enable_gpu is gpu_enabled
        assert synthesizer._model._device == device

    for sample in samples:
        assert sample.columns.tolist() == expected_columns
        assert len(sample) == 10

tests/unit/single_table/test_base.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,52 @@ def test_set_address_columns_warning(self):
369369
['country_column', 'city_column'], anonymization_level='full'
370370
)
371371

372+
def test__resolve_gpu_parameters(self):
    """Test the `_resolve_gpu_parameters` method."""
    # Setup
    metadata = Metadata()
    instance = BaseSingleTableSynthesizer(metadata)
    parameters_with_cuda = {'cuda': True, 'enable_gpu': True}
    parameters_with_cuda_only = {'cuda': True}
    parameters_with_cuda_none = {'cuda': None, 'enable_gpu': True}
    parameters_without_cuda = {'enable_gpu': False}

    # Run
    result_with_cuda = instance._resolve_gpu_parameters(parameters_with_cuda)
    result_with_cuda_only = instance._resolve_gpu_parameters(parameters_with_cuda_only)
    result_with_cuda_none = instance._resolve_gpu_parameters(parameters_with_cuda_none)
    result_without_cuda = instance._resolve_gpu_parameters(parameters_without_cuda)

    # Assert
    # `cuda` is only kept when `enable_gpu` is unset; when both hold a value the
    # deprecated `cuda` entry is dropped.
    assert result_with_cuda == {'enable_gpu': True}
    assert result_with_cuda_only == {'cuda': True}
    assert result_with_cuda_none == {'enable_gpu': True}
    assert result_without_cuda == {'enable_gpu': False}
393+
394+
def test_get_parameters_mock(self):
    """Check that `get_parameters` hands its result to `_resolve_gpu_parameters`."""
    # Setup
    init_kwargs = {
        'enforce_min_max_values': False,
        'enforce_rounding': False,
        'locales': 'en_CA',
    }
    instance = BaseSynthesizer(Metadata(), **init_kwargs)

    # The mocked return value differs from the instantiated values
    # (`enforce_rounding`), so the equality check below can only pass if
    # `get_parameters` returns what the resolver returned.
    resolved_parameters = {
        'enforce_min_max_values': False,
        'enforce_rounding': True,
        'locales': 'en_CA',
    }
    resolver = Mock(return_value=resolved_parameters)
    instance._resolve_gpu_parameters = resolver

    # Run
    returned = instance.get_parameters()

    # Assert
    assert returned == resolved_parameters
    resolver.assert_called_once_with(init_kwargs)
417+
372418
def test_get_parameters(self):
373419
"""Test that it returns every ``init`` parameter without the ``metadata``."""
374420
# Setup

tests/unit/single_table/test_copulagan.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test___init__(self):
4343
assert instance.verbose is False
4444
assert instance.epochs == 300
4545
assert instance.pac == 10
46-
assert instance.cuda is True
46+
assert instance.enable_gpu is True
4747
assert instance.numerical_distributions == {}
4848
assert instance.default_distribution == 'beta'
4949
assert instance._numerical_distributions == {}
@@ -79,7 +79,7 @@ def test___init__with_unified_metadata(self):
7979
assert instance.verbose is False
8080
assert instance.epochs == 300
8181
assert instance.pac == 10
82-
assert instance.cuda is True
82+
assert instance.enable_gpu is True
8383
assert instance.numerical_distributions == {}
8484
assert instance.default_distribution == 'beta'
8585
assert instance._numerical_distributions == {}
@@ -128,7 +128,7 @@ def test___init__custom(self):
128128
verbose=verbose,
129129
epochs=epochs,
130130
pac=pac,
131-
cuda=cuda,
131+
enable_gpu=cuda,
132132
numerical_distributions=numerical_distributions,
133133
default_distribution=default_distribution,
134134
)
@@ -149,7 +149,7 @@ def test___init__custom(self):
149149
assert instance.verbose is True
150150
assert instance.epochs == epochs
151151
assert instance.pac == pac
152-
assert instance.cuda is False
152+
assert instance.enable_gpu is False
153153
assert instance.numerical_distributions == {'field': 'gamma'}
154154
assert instance._numerical_distributions == {'field': GammaUnivariate}
155155
assert instance.default_distribution == 'uniform'
@@ -208,7 +208,7 @@ def test_get_params(self):
208208
'verbose': False,
209209
'epochs': 300,
210210
'pac': 10,
211-
'cuda': True,
211+
'enable_gpu': True,
212212
'numerical_distributions': {},
213213
'default_distribution': 'beta',
214214
}

0 commit comments

Comments
 (0)