1212from ctgan .data_sampler import DataSampler
1313from ctgan .data_transformer import DataTransformer
1414from ctgan .errors import InvalidDataError
15+ from ctgan .synthesizers ._utils import _set_device , validate_and_set_device
1516from ctgan .synthesizers .base import BaseSynthesizer , random_state
1617
1718
@@ -138,8 +139,11 @@ class CTGAN(BaseSynthesizer):
138139 pac (int):
139140 Number of samples to group together when applying the discriminator.
140141 Defaults to 10.
142+ enable_gpu (bool):
143+ Whether to attempt to use GPU for computation.
144+ Defaults to ``True``.
141145 cuda (bool):
142- Whether to attempt to use cuda for GPU computation.
146+ ** Deprecated ** Whether to attempt to use cuda for GPU computation.
143147 If this is False or CUDA is not available, CPU will be used.
144148 Defaults to ``True``.
145149 """
@@ -159,7 +163,8 @@ def __init__(
159163 verbose = False ,
160164 epochs = 300 ,
161165 pac = 10 ,
162- cuda = True ,
166+ enable_gpu = True ,
167+ cuda = None ,
163168 ):
164169 assert batch_size % 2 == 0
165170
@@ -178,16 +183,8 @@ def __init__(
178183 self ._verbose = verbose
179184 self ._epochs = epochs
180185 self .pac = pac
181-
182- if not cuda or not torch .cuda .is_available ():
183- device = 'cpu'
184- elif isinstance (cuda , str ):
185- device = cuda
186- else :
187- device = 'cuda'
188-
189- self ._device = torch .device (device )
190-
186+ self ._device = validate_and_set_device (enable_gpu , cuda )
187+ self ._enable_gpu = cuda if cuda is not None else enable_gpu
191188 self ._transformer = None
192189 self ._data_sampler = None
193190 self ._generator = None
@@ -544,6 +541,6 @@ def sample(self, n, condition_column=None, condition_value=None):
544541
545542 def set_device (self , device ):
546543 """Set the `device` to be used ('GPU' or 'CPU)."""
547- self ._device = device
544+ self ._device = _set_device ( self . _enable_gpu , device )
548545 if self ._generator is not None :
549546 self ._generator .to (self ._device )
0 commit comments