# --- train() setup (reconstructed from mangled diff paste; new-version side) ---
# Put both networks in training mode so layers like BatchNorm/Dropout behave
# correctly during optimization.
G.train()
D.train()

# One Adam optimizer per network; learning rate and beta1 come from CLI flags.
# NOTE(review): assumes `flags` exposes `lr` and `beta1` — the diff renamed
# `learning_rate` -> `lr`; confirm against the flags definition.
d_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)
g_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)

# Number of optimization steps that make up one pass over the dataset.
n_step_epoch = int(len(images_path) // flags.batch_size)
# --- train() main loop (reconstructed from mangled diff paste; new-version side) ---
# For each epoch, alternate one generator and one discriminator update per batch,
# then checkpoint weights and save a sample grid every `save_every_epoch` epochs.
for epoch in range(flags.n_epoch):
    for step, batch_images in enumerate(images):
        step_time = time.time()
        # persistent=True lets us take two separate gradients (G and D) from
        # one recorded forward pass.
        with tf.GradientTape(persistent=True) as tape:
            # Latent noise z ~ N(0, 1), one vector per image in the batch.
            z = np.random.normal(loc=0.0, scale=1.0,
                                 size=[flags.batch_size, flags.z_dim]).astype(np.float32)
            d_logits = D(G(z))          # discriminator scores for fake images
            d2_logits = D(batch_images)  # discriminator scores for real images
            # Discriminator: real images are labelled 1.
            d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
            # Discriminator: generated (fake) images are labelled 0.
            d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
            # Combined loss for the discriminator update.
            d_loss = d_loss_real + d_loss_fake
            # Generator: try to fool the discriminator into outputting 1.
            g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')

        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        # A persistent tape holds resources until explicitly released.
        del tape

        # FIX: the diff used "{:3f}"/"{:5f}" — a minimum *width*, not a
        # precision — so every float printed with 6 decimals. "{:.3f}"/"{:.5f}"
        # is the clearly intended fixed-precision output.
        print("Epoch: [{}/{}] [{}/{}] took: {:.3f}, d_loss: {:.5f}, g_loss: {:.5f}".format(
            epoch, flags.n_epoch, step, n_step_epoch, time.time() - step_time, d_loss, g_loss))

    # Per-epoch checkpoint + sample grid (epoch-keyed condition and filename,
    # so this sits after the step loop).
    if np.mod(epoch, flags.save_every_epoch) == 0:
        G.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
        D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz')
        G.eval()
        # NOTE(review): reuses the last training batch's noise `z`; assumes
        # `images` yields at least one batch per epoch — confirm upstream.
        result = G(z)
        G.train()
        tl.visualize.save_images(result.numpy(), [num_tiles, num_tiles],
                                 '{}/train_{:02d}.png'.format(flags.sample_dir, epoch))
5457
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    train()
0 commit comments