Skip to content

Commit bb59e91

Browse files
authored
Enable GPU usage for MacOS (using MPS) (#464)
1 parent ccd23ca commit bb59e91

File tree

9 files changed

+206
-30
lines changed

9 files changed

+206
-30
lines changed

ctgan/synthesizers/_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import platform
2+
import warnings
3+
4+
import torch
5+
6+
7+
def _get_enable_gpu_value(enable_gpu, cuda):
8+
"""Validate both the `enable_gpu` and `cuda` parameters.
9+
10+
The logic here is to:
11+
- Raise a warning if `cuda` is set because it's deprecated.
12+
- Raise an error if both parameters are set in a conflicting way.
13+
- Return the resolved `enable_gpu` value.
14+
"""
15+
if cuda is not None:
16+
warnings.warn(
17+
'`cuda` parameter is deprecated and will be removed in a future release. '
18+
'Please use `enable_gpu` instead.',
19+
FutureWarning,
20+
)
21+
if not enable_gpu:
22+
raise ValueError(
23+
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
24+
'Please use only `enable_gpu`.'
25+
)
26+
27+
enable_gpu = cuda
28+
29+
return enable_gpu
30+
31+
32+
def _set_device(enable_gpu, device=None):
33+
"""Set the torch device based on the `enable_gpu` parameter and system capabilities."""
34+
if device:
35+
return torch.device(device)
36+
37+
if enable_gpu:
38+
if platform.system() == 'Darwin': # macOS
39+
if (
40+
platform.machine() == 'arm64'
41+
and getattr(torch.backends, 'mps', None)
42+
and torch.backends.mps.is_available()
43+
):
44+
device = 'mps'
45+
else:
46+
device = 'cpu'
47+
else: # Linux/Windows
48+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
49+
else:
50+
device = 'cpu'
51+
52+
return torch.device(device)
53+
54+
55+
def validate_and_set_device(enable_gpu, cuda):
56+
enable_gpu = _get_enable_gpu_value(enable_gpu, cuda)
57+
return _set_device(enable_gpu)

ctgan/synthesizers/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
import torch
77

8+
from ctgan.synthesizers._utils import _set_device
9+
810

911
@contextlib.contextmanager
1012
def set_random_states(random_state, set_model_random_state):
@@ -105,7 +107,7 @@ def __setstate__(self, state):
105107
state['random_states'] = (current_numpy_state, current_torch_state)
106108

107109
self.__dict__ = state
108-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
110+
device = _set_device(enable_gpu=True)
109111
self.set_device(device)
110112

111113
def save(self, path):
@@ -118,7 +120,7 @@ def save(self, path):
118120
@classmethod
119121
def load(cls, path):
120122
"""Load the model stored in the passed `path`."""
121-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
123+
device = _set_device(enable_gpu=True)
122124
model = torch.load(path, weights_only=False)
123125
model.set_device(device)
124126
return model

ctgan/synthesizers/ctgan.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ctgan.data_sampler import DataSampler
1313
from ctgan.data_transformer import DataTransformer
1414
from ctgan.errors import InvalidDataError
15+
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
1516
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1617

1718

@@ -138,8 +139,11 @@ class CTGAN(BaseSynthesizer):
138139
pac (int):
139140
Number of samples to group together when applying the discriminator.
140141
Defaults to 10.
142+
enable_gpu (bool):
143+
Whether to attempt to use GPU for computation.
144+
Defaults to ``True``.
141145
cuda (bool):
142-
Whether to attempt to use cuda for GPU computation.
146+
**Deprecated** Whether to attempt to use cuda for GPU computation.
143147
If this is False or CUDA is not available, CPU will be used.
144148
Defaults to ``True``.
145149
"""
@@ -159,7 +163,8 @@ def __init__(
159163
verbose=False,
160164
epochs=300,
161165
pac=10,
162-
cuda=True,
166+
enable_gpu=True,
167+
cuda=None,
163168
):
164169
assert batch_size % 2 == 0
165170

@@ -178,16 +183,8 @@ def __init__(
178183
self._verbose = verbose
179184
self._epochs = epochs
180185
self.pac = pac
181-
182-
if not cuda or not torch.cuda.is_available():
183-
device = 'cpu'
184-
elif isinstance(cuda, str):
185-
device = cuda
186-
else:
187-
device = 'cuda'
188-
189-
self._device = torch.device(device)
190-
186+
self._device = validate_and_set_device(enable_gpu, cuda)
187+
self._enable_gpu = cuda if cuda is not None else enable_gpu
191188
self._transformer = None
192189
self._data_sampler = None
193190
self._generator = None
@@ -544,6 +541,7 @@ def sample(self, n, condition_column=None, condition_value=None):
544541

545542
def set_device(self, device):
546543
"""Set the `device` to be used ('GPU' or 'CPU)."""
547-
self._device = device
544+
enable_gpu = getattr(self, '_enable_gpu', True)
545+
self._device = _set_device(enable_gpu, device)
548546
if self._generator is not None:
549547
self._generator.to(self._device)

ctgan/synthesizers/tvae.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tqdm import tqdm
1111

1212
from ctgan.data_transformer import DataTransformer
13+
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
1314
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1415

1516

@@ -114,8 +115,9 @@ def __init__(
114115
batch_size=500,
115116
epochs=300,
116117
loss_factor=2,
117-
cuda=True,
118+
enable_gpu=True,
118119
verbose=False,
120+
cuda=None,
119121
):
120122
self.embedding_dim = embedding_dim
121123
self.compress_dims = compress_dims
@@ -127,15 +129,8 @@ def __init__(
127129
self.epochs = epochs
128130
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
129131
self.verbose = verbose
130-
131-
if not cuda or not torch.cuda.is_available():
132-
device = 'cpu'
133-
elif isinstance(cuda, str):
134-
device = cuda
135-
else:
136-
device = 'cuda'
137-
138-
self._device = torch.device(device)
132+
self._device = validate_and_set_device(enable_gpu, cuda)
133+
self._enable_gpu = cuda if cuda is not None else enable_gpu
139134

140135
@random_state
141136
def fit(self, train_data, discrete_columns=()):
@@ -241,6 +236,7 @@ def sample(self, samples):
241236
return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy())
242237

243238
def set_device(self, device):
244-
"""Set the `device` to be used ('GPU' or 'CPU)."""
245-
self._device = device
239+
"""Set the `device` to be used ('GPU' or 'CPU')."""
240+
enable_gpu = getattr(self, '_enable_gpu', True)
241+
self._device = _set_device(enable_gpu, device)
246242
self.decoder.to(self._device)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"pandas>=2.2.3;python_version>='3.13'",
3333
"torch>=1.13.0;python_version<'3.11'",
3434
"torch>=2.0.0;python_version>='3.11' and python_version<'3.12'",
35-
"torch>=2.2.0;python_version>='3.12' and python_version<'3.13'",
35+
"torch>=2.3.0;python_version>='3.12' and python_version<'3.13'",
3636
"torch>=2.6.0;python_version>='3.13'",
3737
'tqdm>=4.29,<5',
3838
'rdt>=1.14.0',

tests/integration/synthesizer/test_ctgan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def test_fixed_random_seed():
236236
})
237237
discrete_columns = ['discrete']
238238

239-
ctgan = CTGAN(epochs=1, cuda=False)
239+
ctgan = CTGAN(epochs=1, enable_gpu=False)
240240

241241
# Run
242242
ctgan.fit(data, discrete_columns)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import platform
2+
import re
3+
from unittest.mock import patch
4+
5+
import pytest
6+
import torch
7+
8+
from ctgan.synthesizers._utils import _get_enable_gpu_value, _set_device, validate_and_set_device
9+
10+
11+
def test__validate_gpu_parameter():
12+
"""Test the ``_get_enable_gpu_value`` method."""
13+
# Setup
14+
expected_error = re.escape(
15+
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
16+
'Please use only `enable_gpu`.'
17+
)
18+
expected_warning = re.escape(
19+
'`cuda` parameter is deprecated and will be removed in a future release. '
20+
'Please use `enable_gpu` instead.'
21+
)
22+
23+
# Run
24+
enable_gpu_1 = _get_enable_gpu_value(False, None)
25+
enable_gpu_2 = _get_enable_gpu_value(True, None)
26+
with pytest.warns(FutureWarning, match=expected_warning):
27+
enable_gpu_3 = _get_enable_gpu_value(True, False)
28+
29+
with pytest.raises(ValueError, match=expected_error):
30+
_get_enable_gpu_value(False, True)
31+
32+
# Assert
33+
assert enable_gpu_1 is False
34+
assert enable_gpu_2 is True
35+
assert enable_gpu_3 is False
36+
37+
38+
def test__set_device():
39+
"""Test the ``_set_device`` method."""
40+
# Run
41+
device_1 = _set_device(False)
42+
device_2 = _set_device(True)
43+
device_3 = _set_device(True, 'cpu')
44+
device_4 = _set_device(enable_gpu=False, device='cpu')
45+
46+
# Assert
47+
if (
48+
platform.machine() == 'arm64'
49+
and getattr(torch.backends, 'mps', None)
50+
and torch.backends.mps.is_available()
51+
):
52+
expected_device_2 = torch.device('mps')
53+
elif torch.cuda.is_available():
54+
expected_device_2 = torch.device('cuda')
55+
else:
56+
expected_device_2 = torch.device('cpu')
57+
58+
assert device_1 == torch.device('cpu')
59+
assert device_2 == expected_device_2
60+
assert device_3 == torch.device('cpu')
61+
assert device_4 == torch.device('cpu')
62+
63+
64+
@patch('ctgan.synthesizers._utils._set_device')
65+
@patch('ctgan.synthesizers._utils._get_enable_gpu_value')
66+
def test_validate_and_set_device(mock_validate, mock_set_device):
67+
"""Test the ``validate_and_set_device`` method."""
68+
# Setup
69+
mock_validate.return_value = True
70+
mock_set_device.return_value = torch.device('cuda')
71+
72+
# Run
73+
device = validate_and_set_device(enable_gpu=True, cuda=None)
74+
75+
# Assert
76+
mock_validate.assert_called_once_with(True, None)
77+
mock_set_device.assert_called_once_with(True)
78+
assert device == torch.device('cuda')

tests/unit/synthesizer/test_ctgan.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""CTGAN unit testing module."""
22

33
from unittest import TestCase
4-
from unittest.mock import Mock
4+
from unittest.mock import Mock, patch
55

66
import numpy as np
77
import pandas as pd
@@ -175,6 +175,30 @@ def _assert_is_between(data, lower, upper):
175175

176176

177177
class TestCTGAN(TestCase):
178+
@patch('ctgan.synthesizers.ctgan.validate_and_set_device')
179+
def test___init__(self, mock_validate_and_set_device):
180+
"""Test the `__init__` method."""
181+
# Setup
182+
mock_validate_and_set_device.return_value = 'cpu'
183+
184+
# Run
185+
synth = CTGAN()
186+
187+
# Assert
188+
assert synth._embedding_dim == 128
189+
assert synth._generator_dim == (256, 256)
190+
assert synth._discriminator_dim == (256, 256)
191+
assert synth._batch_size == 500
192+
assert synth._epochs == 300
193+
assert synth.pac == 10
194+
assert synth.loss_values is None
195+
assert synth._generator is None
196+
assert synth._data_sampler is None
197+
assert synth._verbose is False
198+
assert synth._enable_gpu is True
199+
assert synth._device == 'cpu'
200+
mock_validate_and_set_device.assert_called_once_with(True, None)
201+
178202
def test__apply_activate_(self):
179203
"""Test `_apply_activate` for tables with both continuous and categoricals.
180204

tests/unit/synthesizer/test_tvae.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,27 @@
88

99

1010
class TestTVAE:
11+
@patch('ctgan.synthesizers.tvae.validate_and_set_device')
12+
def test___init__(self, mock_validate_and_set_device):
13+
"""Test the `__init__` method."""
14+
# Setup
15+
mock_validate_and_set_device.return_value = 'cpu'
16+
17+
# Run
18+
synth = TVAE()
19+
20+
# Assert
21+
assert synth.embedding_dim == 128
22+
assert synth.compress_dims == (128, 128)
23+
assert synth.decompress_dims == (128, 128)
24+
assert synth.batch_size == 500
25+
assert synth.epochs == 300
26+
assert synth.loss_values.equals(pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']))
27+
assert synth.verbose is False
28+
assert synth._enable_gpu is True
29+
assert synth._device == 'cpu'
30+
mock_validate_and_set_device.assert_called_once_with(True, None)
31+
1132
@patch('ctgan.synthesizers.tvae._loss_function')
1233
@patch('ctgan.synthesizers.tvae.tqdm')
1334
def test_fit_verbose(self, tqdm_mock, loss_func_mock):

0 commit comments

Comments
 (0)