Commit 0ad909e
1 parent 118f3b6
1 file changed (+42, -58)

ctgan/synthesizers/ctgan.py
@@ -39,15 +39,14 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
         interpolates = alpha * real_data + ((1 - alpha) * fake_data)
 
         disc_interpolates = self(interpolates)
-        self.set_device(device)
-        gradients = torch.autograd.grad(
-            outputs=disc_interpolates,
-            inputs=interpolates,
-            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
-            create_graph=True,
-            retain_graph=True,
-            only_inputs=True,
-        )[0]
+
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', category=UserWarning)
+            gradients = torch.autograd.grad(
+                outputs=disc_interpolates, inputs=interpolates,
+                grad_outputs=torch.ones(disc_interpolates.size(), device=device),
+                create_graph=True, retain_graph=True, only_inputs=True
+            )[0]
 
         gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
         gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
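Aside: both sides of this hunk compute the standard WGAN-GP gradient penalty (Gulrajani et al., 2017). A minimal self-contained sketch of the same idea, assuming a generic critic module for the pac=1 case (CTGAN's Discriminator additionally groups rows into pacs):

# Minimal WGAN-GP penalty sketch; `critic` is a placeholder, not CTGAN's Discriminator.
import torch
from torch import nn

def gradient_penalty(critic, real, fake, lambda_=10.0):
    alpha = torch.rand(real.size(0), 1, device=real.device)  # per-sample mix weight
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    gradients = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Penalize the gradient norm's deviation from 1.
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_

critic = nn.Linear(4, 1)
real, fake = torch.randn(8, 4), torch.randn(8, 4)
print(gradient_penalty(critic, real, fake))  # a scalar tensor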
@@ -143,23 +142,11 @@ class CTGAN(BaseSynthesizer):
             Defaults to ``True``.
     """
 
-    def __init__(
-        self,
-        embedding_dim=128,
-        generator_dim=(256, 256),
-        discriminator_dim=(256, 256),
-        generator_lr=2e-4,
-        generator_decay=1e-6,
-        discriminator_lr=2e-4,
-        discriminator_decay=1e-6,
-        batch_size=500,
-        discriminator_steps=1,
-        log_frequency=True,
-        verbose=False,
-        epochs=300,
-        pac=10,
-        cuda=True,
-    ):
+    def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
+                 generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
+                 discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
+                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):
+
         assert batch_size % 2 == 0
 
         self._embedding_dim = embedding_dim
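For reference, a hypothetical end-to-end call against this signature; the DataFrame and column list below are made-up stand-ins, not from this commit:

# Toy usage sketch; `df` is fabricated data, any real/discrete mix works the same way.
import pandas as pd
from ctgan import CTGAN

df = pd.DataFrame({
    'age': [23, 35, 41, 29, 52, 38] * 50,
    'job': ['nurse', 'chef', 'pilot'] * 100,
})

model = CTGAN(epochs=10, batch_size=500, pac=10, cuda=False)  # batch_size must be even
model.fit(df, discrete_columns=['job'])
synthetic = model.sample(100)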
@@ -254,7 +241,9 @@ def _cond_loss(self, data, c, m):
                     ed = st + span_info.dim
                     ed_c = st_c + span_info.dim
                     tmp = functional.cross_entropy(
-                        data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
+                        data[:, st:ed],
+                        torch.argmax(c[:, st_c:ed_c], dim=1),
+                        reduction='none'
                     )
                     loss.append(tmp)
                     st = ed
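Aside: `reduction='none'` is what lets `_cond_loss` keep one loss value per row, so it can later be masked by `m` before averaging. A toy illustration with fabricated tensors:

# Toy tensors, not CTGAN data: 4 samples over a 3-category one-hot span.
import torch
from torch.nn import functional

logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])  # argmax of the condition-vector span
per_row = functional.cross_entropy(logits, targets, reduction='none')
print(per_row.shape)  # torch.Size([4]) -- unreduced, one entry per sample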
@@ -308,11 +297,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
             epochs = self._epochs
         else:
             warnings.warn(
-                (
-                    '`epochs` argument in `fit` method has been deprecated and will be removed '
-                    'in a future version. Please pass `epochs` to the constructor instead'
-                ),
-                DeprecationWarning,
+                ('`epochs` argument in `fit` method has been deprecated and will be removed '
+                 'in a future version. Please pass `epochs` to the constructor instead'),
+                DeprecationWarning
             )
 
         self._transformer = DataTransformer()
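The warning reflowed here encodes an API migration; a hedged sketch of the two call styles, reusing the toy `df` from the constructor sketch above:

model = CTGAN(epochs=300)                # preferred: epochs on the constructor
model.fit(df, discrete_columns=['job'])

model.fit(df, discrete_columns=['job'], epochs=300)  # deprecated: emits DeprecationWarning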
@@ -321,31 +308,32 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
         train_data = self._transformer.transform(train_data)
 
         self._data_sampler = DataSampler(
-            train_data, self._transformer.output_info_list, self._log_frequency
-        )
+            train_data,
+            self._transformer.output_info_list,
+            self._log_frequency)
 
         data_dim = self._transformer.output_dimensions
 
         self._generator = Generator(
-            self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
+            self._embedding_dim + self._data_sampler.dim_cond_vec(),
+            self._generator_dim,
+            data_dim
         ).to(self._device)
 
         discriminator = Discriminator(
-            data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
+            data_dim + self._data_sampler.dim_cond_vec(),
+            self._discriminator_dim,
+            pac=self.pac
         ).to(self._device)
 
         optimizerG = optim.Adam(
-            self._generator.parameters(),
-            lr=self._generator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._generator_decay,
+            self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
+            weight_decay=self._generator_decay
         )
 
         optimizerD = optim.Adam(
-            discriminator.parameters(),
-            lr=self._discriminator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._discriminator_decay,
+            discriminator.parameters(), lr=self._discriminator_lr,
+            betas=(0.5, 0.9), weight_decay=self._discriminator_decay
         )
 
         mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
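One detail worth flagging while the optimizer calls are reflowed: `betas=(0.5, 0.9)` is a deliberate low-`beta1` choice common in GAN training, not PyTorch's `(0.9, 0.999)` default. A generic illustration with a placeholder module:

# Illustrative only: `net` stands in for the Generator / Discriminator.
import torch
from torch import nn, optim

net = nn.Linear(16, 1)
opt = optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.9), weight_decay=1e-6)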
@@ -361,15 +349,15 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
         steps_per_epoch = max(len(train_data) // self._batch_size, 1)
         for i in epoch_iterator:
             for id_ in range(steps_per_epoch):
+
                 for n in range(self._discriminator_steps):
                     fakez = torch.normal(mean=mean, std=std)
 
                     condvec = self._data_sampler.sample_condvec(self._batch_size)
                     if condvec is None:
                         c1, m1, col, opt = None, None, None, None
                         real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col, opt
-                        )
+                            train_data, self._batch_size, col, opt)
                     else:
                         c1, m1, col, opt = condvec
                         c1 = torch.from_numpy(c1).to(self._device)
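For orientation, `fakez` is a fresh batch of standard-normal latents on every critic step; `mean` and `std` were built once before the loop. A toy version with assumed sizes:

# Toy latent draw; batch and embedding sizes here are assumptions.
import torch

batch_size, embedding_dim = 8, 128
mean = torch.zeros(batch_size, embedding_dim)
std = mean + 1
fakez = torch.normal(mean=mean, std=std)  # one N(0, 1) row per sample
print(fakez.shape)  # torch.Size([8, 128])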
@@ -379,8 +367,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                         perm = np.arange(self._batch_size)
                         np.random.shuffle(perm)
                         real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col[perm], opt[perm]
-                        )
+                            train_data, self._batch_size, col[perm], opt[perm])
                         c2 = c1[perm]
 
                     fake = self._generator(fakez)
@@ -399,8 +386,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                     y_real = discriminator(real_cat)
 
                     pen = discriminator.calc_gradient_penalty(
-                        real_cat, fake_cat, self._device, self.pac
-                    )
+                        real_cat, fake_cat, self._device, self.pac)
                     loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
 
                     optimizerD.zero_grad(set_to_none=False)
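The line below the reflowed penalty call is the usual Wasserstein critic objective, loss_d = -(E[D(real)] - E[D(fake)]). Sign check on made-up scores:

# Fabricated critic scores, just to verify the sign convention.
import torch

y_real = torch.tensor([0.8, 1.2])
y_fake = torch.tensor([-0.5, 0.1])
loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
print(loss_d)  # tensor(-1.2000): minimizing widens the real/fake score gap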
@@ -444,12 +430,12 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                 epoch_loss_df = pd.DataFrame({
                     'Epoch': [i],
                     'Generator Loss': [generator_loss],
-                    'Discriminator Loss': [discriminator_loss],
+                    'Discriminator Loss': [discriminator_loss]
                 })
                 if not self.loss_values.empty:
-                    self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
-                        drop=True
-                    )
+                    self.loss_values = pd.concat(
+                        [self.loss_values, epoch_loss_df]
+                    ).reset_index(drop=True)
                 else:
                     self.loss_values = epoch_loss_df
 
@@ -479,11 +465,9 @@ def sample(self, n, condition_column=None, condition_value=None):
         """
         if condition_column is not None and condition_value is not None:
             condition_info = self._transformer.convert_column_name_value_to_id(
-                condition_column, condition_value
-            )
+                condition_column, condition_value)
             global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
-                condition_info, self._batch_size
-            )
+                condition_info, self._batch_size)
         else:
             global_condition_vec = None
 
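A hypothetical conditional-sampling call matching this signature; the column and value are assumptions, and `model` is a fitted CTGAN as in the earlier sketch:

samples = model.sample(n=500, condition_column='job', condition_value='chef')
print(samples['job'].value_counts())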
