@@ -346,76 +346,66 @@ def get_gen_order(self) -> list:
346346 return gen_order
347347
348348 def training_step (
349- self , batch : torch .Tensor , batch_idx : int , ** kwargs
349+ self , batch : torch .Tensor , batch_idx : int , optimizer_idx : int
350350 ) -> OrderedDict :
351- optimizer_g , optimizer_d = self .optimizers ()
352-
353351 # sample noise
354352 z = self .sample_z (batch .shape [0 ])
355353 z = z .type_as (batch )
356354 generated_batch = self .generator .sequential (batch , z , self .get_gen_order ())
357355
358- # train disc
359- #
360- self .toggle_optimizer (optimizer_d )
361- self .iterations_d += 1
362- # Measure discriminator's ability to classify real from generated samples
356+ # train generator
357+ if optimizer_idx == 0 :
358+ self .iterations_d += 1
359+ # Measure discriminator's ability to classify real from generated samples
363360
364- # how well can it label as real?
365- real_loss = torch .mean (self .discriminator (batch ))
366- fake_loss = torch .mean (self .discriminator (generated_batch .detach ()))
361+ # how well can it label as real?
362+ real_loss = torch .mean (self .discriminator (batch ))
363+ fake_loss = torch .mean (self .discriminator (generated_batch .detach ()))
367364
368- # discriminator loss
369- d_loss = fake_loss - real_loss
365+ # discriminator loss
366+ d_loss = fake_loss - real_loss
370367
371- # add the gradient penalty
372- d_loss += self .hparams .lambda_gp * self .compute_gradient_penalty (
373- batch , generated_batch
374- )
375- if torch .isnan (d_loss ).sum () != 0 :
376- raise ValueError ("NaN in the discr loss" )
377-
378- self .manual_backward (d_loss )
379- optimizer_d .step ()
380- optimizer_d .zero_grad ()
381- self .untoggle_optimizer (optimizer_d )
382- # train gen
383- #
384- # sanity check: keep track of G updates
385- self .toggle_optimizer (optimizer_g )
386- self .iterations_g += 1
387-
388- # adversarial loss (negative D fake loss)
389- generated_batch = self .generator .sequential (batch , z , self .get_gen_order ())
390- g_loss = - torch .mean (
391- self .discriminator (generated_batch )
392- ) # self.adversarial_loss(self.discriminator(self.generated_batch), valid)
368+ # add the gradient penalty
369+ d_loss += self .hparams .lambda_gp * self .compute_gradient_penalty (
370+ batch , generated_batch
371+ )
372+ if torch .isnan (d_loss ).sum () != 0 :
373+ raise ValueError ("NaN in the discr loss" )
374+
375+ return d_loss
376+ elif optimizer_idx == 1 :
377+ # sanity check: keep track of G updates
378+ self .iterations_g += 1
379+
380+ # adversarial loss (negative D fake loss)
381+ g_loss = - torch .mean (
382+ self .discriminator (generated_batch )
383+ ) # self.adversarial_loss(self.discriminator(self.generated_batch), valid)
384+
385+ # add privacy loss of ADS-GAN
386+ g_loss += self .hparams .lambda_privacy * self .privacy_loss (
387+ batch , generated_batch
388+ )
393389
394- # add privacy loss of ADS-GAN
395- g_loss += self .hparams .lambda_privacy * self .privacy_loss (
396- batch , generated_batch
397- )
390+ # add l1 regularization loss
391+ g_loss += self .hparams .l1_g * self .l1_reg (self .generator )
398392
399- # add l1 regularization loss
400- g_loss += self .hparams .l1_g * self .l1_reg (self .generator )
393+ if len (self .dag_seed ) == 0 :
394+ if self .hparams .grad_dag_loss :
395+ g_loss += self .gradient_dag_loss (batch , z )
396+ if torch .isnan (g_loss ).sum () != 0 :
397+ raise ValueError ("NaN in the gen loss" )
401398
402- if len (self .dag_seed ) == 0 :
403- if self .hparams .grad_dag_loss :
404- g_loss += self .gradient_dag_loss (batch , z )
405- if torch .isnan (g_loss ).sum () != 0 :
406- raise ValueError ("NaN in the gen loss" )
407- self .manual_backward (g_loss )
408- optimizer_g .step ()
409- optimizer_g .zero_grad ()
410- self .untoggle_optimizer (optimizer_g )
399+ return g_loss
400+ else :
401+ raise ValueError ("should not get here" )
411402
412403 def configure_optimizers (self ) -> tuple :
413404 lr = self .hparams .lr
414405 b1 = self .hparams .b1
415406 b2 = self .hparams .b2
416407 weight_decay = self .hparams .weight_decay
417408
418- self .automatic_optimization = False
419409 opt_g = torch .optim .AdamW (
420410 self .generator .parameters (),
421411 lr = lr ,
0 commit comments