@@ -39,15 +39,14 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
         interpolates = alpha * real_data + ((1 - alpha) * fake_data)

         disc_interpolates = self(interpolates)
-        self.set_device(device)
-        gradients = torch.autograd.grad(
-            outputs=disc_interpolates,
-            inputs=interpolates,
-            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
-            create_graph=True,
-            retain_graph=True,
-            only_inputs=True,
-        )[0]
+
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', category=UserWarning)
+            gradients = torch.autograd.grad(
+                outputs=disc_interpolates, inputs=interpolates,
+                grad_outputs=torch.ones(disc_interpolates.size(), device=device),
+                create_graph=True, retain_graph=True, only_inputs=True
+            )[0]

         gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
         gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
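
Both sides of this hunk implement the standard WGAN-GP gradient penalty (Gulrajani et al., 2017): score interpolates between real and fake rows, differentiate the critic's output with respect to those interpolates, and penalise the squared deviation of the gradient norm from 1. A minimal standalone sketch, assuming `pac=1` and a placeholder `critic` module:

```python
import torch

def gradient_penalty_sketch(critic, real, fake, device='cpu', lambda_=10.0):
    """WGAN-GP penalty for pac=1; `critic` is any nn.Module mapping rows to scores."""
    alpha = torch.rand(real.size(0), 1, device=device)  # one mixing weight per row
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    gradients = torch.autograd.grad(
        outputs=scores, inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True, only_inputs=True,
    )[0]
    # Penalise deviation of each row's gradient norm from 1.
    return lambda_ * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
```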
@@ -143,23 +142,11 @@ class CTGAN(BaseSynthesizer):
            Defaults to ``True``.
    """

-    def __init__(
-        self,
-        embedding_dim=128,
-        generator_dim=(256, 256),
-        discriminator_dim=(256, 256),
-        generator_lr=2e-4,
-        generator_decay=1e-6,
-        discriminator_lr=2e-4,
-        discriminator_decay=1e-6,
-        batch_size=500,
-        discriminator_steps=1,
-        log_frequency=True,
-        verbose=False,
-        epochs=300,
-        pac=10,
-        cuda=True,
-    ):
+    def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
+                 generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
+                 discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
+                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):
+
         assert batch_size % 2 == 0

         self._embedding_dim = embedding_dim
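
Both forms of the signature accept the same keyword arguments. A usage sketch with the defaults spelled out; note the runtime constraints (`batch_size` must be even per the assert above, and the pac-grouped critic additionally expects it to be divisible by `pac`):

```python
from ctgan import CTGAN

model = CTGAN(
    embedding_dim=128,
    generator_dim=(256, 256),
    discriminator_dim=(256, 256),
    batch_size=500,   # even, and divisible by pac
    epochs=300,
    pac=10,
    cuda=True,
)
```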
@@ -254,7 +241,9 @@ def _cond_loss(self, data, c, m):
                ed = st + span_info.dim
                ed_c = st_c + span_info.dim
                tmp = functional.cross_entropy(
-                    data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
+                    data[:, st:ed],
+                    torch.argmax(c[:, st_c:ed_c], dim=1),
+                    reduction='none'
                )
                loss.append(tmp)
                st = ed
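
For reference, this call scores a single discrete span: `data[:, st:ed]` holds the generator's logits for the column, `torch.argmax(c[:, st_c:ed_c], dim=1)` recovers the conditioned category from the one-hot conditional vector, and `reduction='none'` keeps per-row losses so the mask `m` can zero out rows where a different column was conditioned. A toy illustration with made-up shapes:

```python
import torch
from torch.nn import functional

# One 3-category span, batch of 4 rows (shapes are illustrative only).
data = torch.randn(4, 3)                                       # generator logits for the span
c = functional.one_hot(torch.tensor([0, 2, 1, 2]), 3).float()  # conditioned categories, one-hot
m = torch.tensor([1.0, 0.0, 1.0, 1.0])                         # 1 where this column was conditioned

tmp = functional.cross_entropy(data, torch.argmax(c, dim=1), reduction='none')
masked = (tmp * m).sum() / data.size(0)  # mask, then average over the batch
```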
@@ -308,11 +297,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
            epochs = self._epochs
        else:
            warnings.warn(
-                (
-                    '`epochs` argument in `fit` method has been deprecated and will be removed '
-                    'in a future version. Please pass `epochs` to the constructor instead'
-                ),
-                DeprecationWarning,
+                ('`epochs` argument in `fit` method has been deprecated and will be removed '
+                 'in a future version. Please pass `epochs` to the constructor instead'),
+                DeprecationWarning
            )

        self._transformer = DataTransformer()
@@ -321,31 +308,32 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
        train_data = self._transformer.transform(train_data)

        self._data_sampler = DataSampler(
-            train_data, self._transformer.output_info_list, self._log_frequency
-        )
+            train_data,
+            self._transformer.output_info_list,
+            self._log_frequency)

        data_dim = self._transformer.output_dimensions

        self._generator = Generator(
-            self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
+            self._embedding_dim + self._data_sampler.dim_cond_vec(),
+            self._generator_dim,
+            data_dim
        ).to(self._device)

        discriminator = Discriminator(
-            data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
+            data_dim + self._data_sampler.dim_cond_vec(),
+            self._discriminator_dim,
+            pac=self.pac
        ).to(self._device)

        optimizerG = optim.Adam(
-            self._generator.parameters(),
-            lr=self._generator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._generator_decay,
+            self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
+            weight_decay=self._generator_decay
        )

        optimizerD = optim.Adam(
-            discriminator.parameters(),
-            lr=self._discriminator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._discriminator_decay,
+            discriminator.parameters(), lr=self._discriminator_lr,
+            betas=(0.5, 0.9), weight_decay=self._discriminator_decay
        )

        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
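
The wiring above fixes the network shapes: the generator maps noise concatenated with a conditional vector to the transformed data width, while the critic internally flattens `pac` samples (plus their conditional vectors) into one input. A dimension check under assumed sizes; the concrete numbers here are invented for illustration:

```python
# Assumed sizes for illustration only.
embedding_dim = 128   # latent noise width
cond_dim = 37         # hypothetical data_sampler.dim_cond_vec()
data_dim = 214        # hypothetical transformer.output_dimensions
pac = 10

gen_in = embedding_dim + cond_dim      # 165: Generator input width
gen_out = data_dim                     # 214: Generator output width
disc_in = (data_dim + cond_dim) * pac  # 2510: Discriminator flattens pac rows
print(gen_in, gen_out, disc_in)
```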
@@ -361,15 +349,15 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
        for i in epoch_iterator:
            for id_ in range(steps_per_epoch):
+
                for n in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col, opt
-                        )
+                            train_data, self._batch_size, col, opt)
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
@@ -379,8 +367,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col[perm], opt[perm]
-                        )
+                            train_data, self._batch_size, col[perm], opt[perm])
                        c2 = c1[perm]

                    fake = self._generator(fakez)
@@ -399,8 +386,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
-                        real_cat, fake_cat, self._device, self.pac
-                    )
+                        real_cat, fake_cat, self._device, self.pac)
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad(set_to_none=False)
@@ -444,12 +430,12 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
            epoch_loss_df = pd.DataFrame({
                'Epoch': [i],
                'Generator Loss': [generator_loss],
-                'Discriminator Loss': [discriminator_loss],
+                'Discriminator Loss': [discriminator_loss]
            })
            if not self.loss_values.empty:
-                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
-                    drop=True
-                )
+                self.loss_values = pd.concat(
+                    [self.loss_values, epoch_loss_df]
+                ).reset_index(drop=True)
            else:
                self.loss_values = epoch_loss_df
@@ -479,11 +465,9 @@ def sample(self, n, condition_column=None, condition_value=None):
        """
        if condition_column is not None and condition_value is not None:
            condition_info = self._transformer.convert_column_name_value_to_id(
-                condition_column, condition_value
-            )
+                condition_column, condition_value)
            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
-                condition_info, self._batch_size
-            )
+                condition_info, self._batch_size)
        else:
            global_condition_vec = None
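
Call sites are unaffected by the reshuffle above; a short usage sketch, with hypothetical column and category names:

```python
# 'payment_type' / 'credit_card' are placeholder names for illustration.
samples = model.sample(
    n=1000,
    condition_column='payment_type',
    condition_value='credit_card',
)
```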