Skip to content

Commit 1938413

Browse files
committed
def 462
1 parent ccd23ca commit 1938413

File tree

4 files changed

+73
-26
lines changed

4 files changed

+73
-26
lines changed

ctgan/synthesizers/_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import sys
2+
import warnings
3+
4+
import torch
5+
6+
7+
def _validate_gpu_parameters(enable_gpu, cuda):
8+
"""Validate both the `enable_gpu` and `cuda` parameters.
9+
10+
The logic here is to:
11+
- Raise a warning if `cuda` is set because it's deprecated.
12+
- Raise an error if both parameters are set in a conflicting way.
13+
- Return the resolved `enable_gpu` value.
14+
"""
15+
if cuda is not None:
16+
warnings.warn(
17+
'`cuda` parameter is deprecated and will be removed in a future release. '
18+
'Please use `enable_gpu` instead.',
19+
FutureWarning,
20+
)
21+
if not enable_gpu:
22+
raise ValueError(
23+
'Cannot resolve the provided values of `cuda` and `enable_gpu` parameters. '
24+
'Please use only `enable_gpu`.'
25+
)
26+
27+
enable_gpu = cuda
28+
29+
return enable_gpu
30+
31+
32+
def _set_device(enable_gpu, device=None):
33+
"""Set the torch device based on the `enable_gpu` parameter and system capabilities."""
34+
if device:
35+
return torch.device(device)
36+
37+
if enable_gpu:
38+
if sys.platform == 'darwin': # macOS
39+
if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
40+
device = 'mps'
41+
else:
42+
device = 'cpu'
43+
else: # Linux/Windows
44+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
45+
else:
46+
device = 'cpu'
47+
48+
return torch.device(device)
49+
50+
51+
def validate_and_set_device(enable_gpu, cuda):
52+
enable_gpu = _validate_gpu_parameters(enable_gpu, cuda)
53+
return _set_device(enable_gpu)

ctgan/synthesizers/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
import torch
77

8+
from ctgan.synthesizers._utils import _set_device
9+
810

911
@contextlib.contextmanager
1012
def set_random_states(random_state, set_model_random_state):
@@ -118,7 +120,7 @@ def save(self, path):
118120
@classmethod
119121
def load(cls, path):
120122
"""Load the model stored in the passed `path`."""
121-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
123+
device = _set_device(enable_gpu=True)
122124
model = torch.load(path, weights_only=False)
123125
model.set_device(device)
124126
return model

ctgan/synthesizers/ctgan.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ctgan.data_sampler import DataSampler
1313
from ctgan.data_transformer import DataTransformer
1414
from ctgan.errors import InvalidDataError
15+
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
1516
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1617

1718

@@ -138,8 +139,11 @@ class CTGAN(BaseSynthesizer):
138139
pac (int):
139140
Number of samples to group together when applying the discriminator.
140141
Defaults to 10.
142+
enable_gpu (bool):
143+
Whether to attempt to use GPU for computation.
144+
Defaults to ``True``.
141145
cuda (bool):
142-
Whether to attempt to use cuda for GPU computation.
146+
** Deprecated ** Whether to attempt to use cuda for GPU computation.
143147
If this is False or CUDA is not available, CPU will be used.
144148
Defaults to ``True``.
145149
"""
@@ -159,7 +163,8 @@ def __init__(
159163
verbose=False,
160164
epochs=300,
161165
pac=10,
162-
cuda=True,
166+
enable_gpu=True,
167+
cuda=None,
163168
):
164169
assert batch_size % 2 == 0
165170

@@ -178,16 +183,8 @@ def __init__(
178183
self._verbose = verbose
179184
self._epochs = epochs
180185
self.pac = pac
181-
182-
if not cuda or not torch.cuda.is_available():
183-
device = 'cpu'
184-
elif isinstance(cuda, str):
185-
device = cuda
186-
else:
187-
device = 'cuda'
188-
189-
self._device = torch.device(device)
190-
186+
self._device = validate_and_set_device(enable_gpu, cuda)
187+
self._enable_gpu = cuda if cuda is not None else enable_gpu
191188
self._transformer = None
192189
self._data_sampler = None
193190
self._generator = None
@@ -544,6 +541,6 @@ def sample(self, n, condition_column=None, condition_value=None):
544541

545542
def set_device(self, device):
546543
"""Set the `device` to be used ('GPU' or 'CPU)."""
547-
self._device = device
544+
self._device = _set_device(self._enable_gpu, device)
548545
if self._generator is not None:
549546
self._generator.to(self._device)

ctgan/synthesizers/tvae.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tqdm import tqdm
1111

1212
from ctgan.data_transformer import DataTransformer
13+
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
1314
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1415

1516

@@ -114,8 +115,9 @@ def __init__(
114115
batch_size=500,
115116
epochs=300,
116117
loss_factor=2,
117-
cuda=True,
118+
enable_gpu=True,
118119
verbose=False,
120+
cuda=None,
119121
):
120122
self.embedding_dim = embedding_dim
121123
self.compress_dims = compress_dims
@@ -127,15 +129,8 @@ def __init__(
127129
self.epochs = epochs
128130
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
129131
self.verbose = verbose
130-
131-
if not cuda or not torch.cuda.is_available():
132-
device = 'cpu'
133-
elif isinstance(cuda, str):
134-
device = cuda
135-
else:
136-
device = 'cuda'
137-
138-
self._device = torch.device(device)
132+
self._device = validate_and_set_device(enable_gpu, cuda)
133+
self._enable_gpu = cuda if cuda is not None else enable_gpu
139134

140135
@random_state
141136
def fit(self, train_data, discrete_columns=()):
@@ -241,6 +236,6 @@ def sample(self, samples):
241236
return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy())
242237

243238
def set_device(self, device):
244-
"""Set the `device` to be used ('GPU' or 'CPU)."""
245-
self._device = device
239+
"""Set the `device` to be used ('GPU' or 'CPU')."""
240+
self._device = _set_device(self._enable_gpu, device)
246241
self.decoder.to(self._device)

0 commit comments

Comments
 (0)