Commit ee9adeb

pytorch lightning fixes
This reverts commit 02d9bb6.
1 parent 5725511 commit ee9adeb

File tree

2 files changed: +42 −52 lines

decaf/DECAF.py

Lines changed: 41 additions & 51 deletions
@@ -346,76 +346,66 @@ def get_gen_order(self) -> list:
         return gen_order

     def training_step(
-        self, batch: torch.Tensor, batch_idx: int, **kwargs
+        self, batch: torch.Tensor, batch_idx: int, optimizer_idx: int
     ) -> OrderedDict:
-        optimizer_g, optimizer_d = self.optimizers()
-
         # sample noise
         z = self.sample_z(batch.shape[0])
         z = z.type_as(batch)
         generated_batch = self.generator.sequential(batch, z, self.get_gen_order())

-        # train disc
-        #
-        self.toggle_optimizer(optimizer_d)
-        self.iterations_d += 1
-        # Measure discriminator's ability to classify real from generated samples
+        # train generator
+        if optimizer_idx == 0:
+            self.iterations_d += 1
+            # Measure discriminator's ability to classify real from generated samples

-        # how well can it label as real?
-        real_loss = torch.mean(self.discriminator(batch))
-        fake_loss = torch.mean(self.discriminator(generated_batch.detach()))
+            # how well can it label as real?
+            real_loss = torch.mean(self.discriminator(batch))
+            fake_loss = torch.mean(self.discriminator(generated_batch.detach()))

-        # discriminator loss
-        d_loss = fake_loss - real_loss
+            # discriminator loss
+            d_loss = fake_loss - real_loss

-        # add the gradient penalty
-        d_loss += self.hparams.lambda_gp * self.compute_gradient_penalty(
-            batch, generated_batch
-        )
-        if torch.isnan(d_loss).sum() != 0:
-            raise ValueError("NaN in the discr loss")
-
-        self.manual_backward(d_loss)
-        optimizer_d.step()
-        optimizer_d.zero_grad()
-        self.untoggle_optimizer(optimizer_d)
-        # train gen
-        #
-        # sanity check: keep track of G updates
-        self.toggle_optimizer(optimizer_g)
-        self.iterations_g += 1
-
-        # adversarial loss (negative D fake loss)
-        generated_batch = self.generator.sequential(batch, z, self.get_gen_order())
-        g_loss = -torch.mean(
-            self.discriminator(generated_batch)
-        )  # self.adversarial_loss(self.discriminator(self.generated_batch), valid)
+            # add the gradient penalty
+            d_loss += self.hparams.lambda_gp * self.compute_gradient_penalty(
+                batch, generated_batch
+            )
+            if torch.isnan(d_loss).sum() != 0:
+                raise ValueError("NaN in the discr loss")
+
+            return d_loss
+        elif optimizer_idx == 1:
+            # sanity check: keep track of G updates
+            self.iterations_g += 1
+
+            # adversarial loss (negative D fake loss)
+            g_loss = -torch.mean(
+                self.discriminator(generated_batch)
+            )  # self.adversarial_loss(self.discriminator(self.generated_batch), valid)
+
+            # add privacy loss of ADS-GAN
+            g_loss += self.hparams.lambda_privacy * self.privacy_loss(
+                batch, generated_batch
+            )

-        # add privacy loss of ADS-GAN
-        g_loss += self.hparams.lambda_privacy * self.privacy_loss(
-            batch, generated_batch
-        )
+            # add l1 regularization loss
+            g_loss += self.hparams.l1_g * self.l1_reg(self.generator)

-        # add l1 regularization loss
-        g_loss += self.hparams.l1_g * self.l1_reg(self.generator)
+            if len(self.dag_seed) == 0:
+                if self.hparams.grad_dag_loss:
+                    g_loss += self.gradient_dag_loss(batch, z)
+            if torch.isnan(g_loss).sum() != 0:
+                raise ValueError("NaN in the gen loss")

-        if len(self.dag_seed) == 0:
-            if self.hparams.grad_dag_loss:
-                g_loss += self.gradient_dag_loss(batch, z)
-        if torch.isnan(g_loss).sum() != 0:
-            raise ValueError("NaN in the gen loss")
-        self.manual_backward(g_loss)
-        optimizer_g.step()
-        optimizer_g.zero_grad()
-        self.untoggle_optimizer(optimizer_g)
+            return g_loss
+        else:
+            raise ValueError("should not get here")

     def configure_optimizers(self) -> tuple:
         lr = self.hparams.lr
         b1 = self.hparams.b1
         b2 = self.hparams.b2
         weight_decay = self.hparams.weight_decay

-        self.automatic_optimization = False
         opt_g = torch.optim.AdamW(
             self.generator.parameters(),
             lr=lr,
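
For context, the restored training_step signature follows the pre-2.0 PyTorch Lightning convention for multiple optimizers: under automatic optimization, Lightning calls training_step once per optimizer, passes optimizer_idx, and each branch simply returns its loss, instead of fetching the optimizers via self.optimizers() and calling manual_backward and step by hand. Below is a minimal, self-contained sketch of that pattern, not the DECAF model itself; the linear generator/discriminator, dimensions, and WGAN-style losses are illustrative assumptions only.

import torch
import pytorch_lightning as pl


class TinyGAN(pl.LightningModule):
    # Illustrative stand-in: a linear "generator" and "discriminator" trained
    # with WGAN-style losses, just to show the optimizer_idx control flow.
    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.generator = torch.nn.Linear(dim, dim)
        self.discriminator = torch.nn.Linear(dim, 1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        z = torch.randn_like(batch)
        generated = self.generator(z)

        if optimizer_idx == 0:
            # discriminator step: score real samples higher than generated ones
            d_loss = torch.mean(self.discriminator(generated.detach())) - torch.mean(
                self.discriminator(batch)
            )
            return d_loss
        elif optimizer_idx == 1:
            # generator step: raise the discriminator's score on generated samples
            g_loss = -torch.mean(self.discriminator(generated))
            return g_loss
        raise ValueError("should not get here")

    def configure_optimizers(self):
        # the return order defines optimizer_idx: 0 -> discriminator, 1 -> generator
        opt_d = torch.optim.AdamW(self.discriminator.parameters(), lr=1e-3)
        opt_g = torch.optim.AdamW(self.generator.parameters(), lr=1e-3)
        return [opt_d, opt_g]


if __name__ == "__main__":
    # a DataLoader over a random tensor is enough to exercise both branches
    loader = torch.utils.data.DataLoader(torch.randn(256, 8), batch_size=32)
    trainer = pl.Trainer(max_epochs=1, logger=False)
    trainer.fit(TinyGAN(), loader)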

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ loguru
 networkx>=2.0,<3.0
 numpy>=1.19
 pandas
-pytorch-lightning>=1.4
+pytorch-lightning<2.0
 scikit-learn
 scipy
 torch>=1.9
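
This pin is the companion change: Lightning 2.x dropped the optimizer_idx argument from training_step (multiple optimizers now require manual optimization), so the reverted DECAF.py code only runs on a 1.x release. A quick sanity check of the installed version, assuming pytorch_lightning is importable in the current environment:

import pytorch_lightning as pl

# pytorch_lightning exposes its version string; reject 2.x installs early
major = int(pl.__version__.split(".")[0])
if major >= 2:
    raise RuntimeError(
        f"pytorch-lightning {pl.__version__} does not support optimizer_idx; "
        "install a <2.0 release as pinned in requirements.txt"
    )
print(f"pytorch-lightning {pl.__version__} satisfies the <2.0 pin")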
