"""TensorLayer implementation of a Deep Convolutional Generative Adversarial Network (DCGAN).

Uses a deep convolutional generative adversarial network (DCGAN)
to generate face images from a noise distribution.

References:
- Generative Adversarial Nets.
  Goodfellow et al. arXiv: 1406.2661.
- Unsupervised Representation Learning with Deep Convolutional
  Generative Adversarial Networks. A. Radford, L. Metz, S. Chintala.
  arXiv: 1511.06434.

Links:
- [GAN Paper](https://arxiv.org/pdf/1406.2661.pdf)
- [DCGAN Paper](https://arxiv.org/abs/1511.06434)

Usage:
- See README.md
"""
import os
import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl

from glob import glob
from random import shuffle

from model import generator_simplified_api, discriminator_simplified_api
from utils import get_image

# Define TF Flags
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The number of batch images [64]")
flags.DEFINE_integer("image_size", 108, "The size of image to use (will be center cropped) [108]")
flags.DEFINE_integer("output_size", 64, "The size of the output images to produce [64]")
# The following flags are referenced later in this script; their default
# values here are assumptions (typical DCGAN settings).
flags.DEFINE_integer("sample_size", 64, "The number of sample images [64]")
flags.DEFINE_integer("c_dim", 3, "Dimension of image color [3]")
flags.DEFINE_integer("sample_step", 500, "The interval of generating samples [500]")
flags.DEFINE_integer("save_step", 500, "The interval of saving checkpoints [500]")
flags.DEFINE_string("dataset", "celebA", "The name of the dataset [celebA]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("is_crop", True, "True for center-cropping input images [True]")
FLAGS = flags.FLAGS

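# Example invocation (the script name "main.py" is an assumption; flags are
# parsed from the command line by tf.app.run()):
#   python main.py --dataset celebA --epoch 25 --batch_size 64
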
def main(_):
    # Print flags
    for flag, _ in FLAGS.__flags.items():
        print('"{}": {}'.format(flag, getattr(FLAGS, flag)))
    print("--------------------")

    # Configure checkpoint/samples dir
    tl.files.exists_or_mkdir(FLAGS.checkpoint_dir)
    tl.files.exists_or_mkdir(FLAGS.sample_dir)

    z_dim = 100  # noise dim

    # Construct graph on GPU
    with tf.device("/gpu:0"):
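        # Hard device pinning: every op below is placed on the first GPU.
        # In TF 1.x this errors on CPU-only machines unless the session is
        # created with allow_soft_placement=True.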

        """ Define Models """
        z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
        real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')

        # Input noise into generator for training
        net_g, g_logits = generator_simplified_api(z, is_train=True, reuse=False)

        # Input real and generated fake images into discriminator for training
        net_d, d_logits = discriminator_simplified_api(net_g.outputs, is_train=True, reuse=False)
        net_d2, d2_logits = discriminator_simplified_api(real_images, is_train=True, reuse=True)

        # Input noise into generator for evaluation
        # set is_train to False so that BatchNormLayer behaves differently
        net_g2, g2_logits = generator_simplified_api(z, is_train=False, reuse=True)
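
        # reuse=True rebuilds these towers inside the same variable scopes, so
        # the evaluation generator and the second discriminator share weights
        # with their training counterparts instead of creating new variables.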

        """ Define Training Operations """
        # cost for updating discriminator and generator
        # discriminator: real images are labelled as 1
        d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')

        # discriminator: images from generator (fake) are labelled as 0
        d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
        d_loss = d_loss_real + d_loss_fake

        # generator: try to make the fake images look real (1)
        g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')
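
        # Both losses are the sigmoid cross-entropy form of the GAN objective
        # (Goodfellow et al. 2014). The generator uses the non-saturating
        # variant: rather than minimizing log(1 - D(G(z))), it maximizes
        # log D(G(z)) by labelling its fakes as 1, which gives stronger
        # gradients early in training when the discriminator easily rejects fakes.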

        g_vars = tl.layers.get_variables_with_name('generator', True, True)
        d_vars = tl.layers.get_variables_with_name('discriminator', True, True)
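
        # Variables are selected by name scope, so each optimizer below updates
        # only its own sub-network: the generator stays frozen while the
        # discriminator trains, and vice versa.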

        # Define optimizers for updating discriminator and generator
        d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
            .minimize(d_loss, var_list=d_vars)
        g_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
            .minimize(g_loss, var_list=g_vars)
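
        # beta1=0.5 follows the DCGAN paper; Radford et al. report that the
        # default Adam momentum of 0.9 made adversarial training oscillate.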

    # Init Session
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    model_dir = "%s_%s_%s" % (FLAGS.dataset, FLAGS.batch_size, FLAGS.output_size)
    save_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
    tl.files.exists_or_mkdir(FLAGS.sample_dir)
    tl.files.exists_or_mkdir(save_dir)

    # load the latest checkpoints
    net_g_name = os.path.join(save_dir, 'net_g.npz')
    net_d_name = os.path.join(save_dir, 'net_d.npz')

    # Collect the training image paths (the "./data/<dataset>/*.jpg" layout is assumed)
    data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))

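    # Fixed evaluation noise: reusing the same seed makes the sample grids
    # comparable across training steps.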
    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
    # sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)

    """ Training models """
    iter_counter = 0
    for epoch in range(FLAGS.epoch):

        # Shuffle data
        shuffle(data_files)

        # Update sample files based on shuffled data
        sample_files = data_files[0:FLAGS.sample_size]
        sample = [get_image(sample_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0) for sample_file in sample_files]
        sample_images = np.array(sample).astype(np.float32)
        print("[*] Sample images updated!")

        # Load image data
        batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size

        for idx in range(0, batch_idxs):
            batch_files = data_files[idx * FLAGS.batch_size:(idx + 1) * FLAGS.batch_size]

            # Get real images (more image augmentation functions at http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html)
            batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0) for batch_file in batch_files]
            batch_images = np.array(batch).astype(np.float32)
            # Noise for this step; its shape must match the z placeholder, so use batch_size (not sample_size)
            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
            start_time = time.time()

            # Update the Discriminator (D)
            errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})

            # Update the Generator (G)
            # run the generator twice to make sure that d_loss does not go to zero (different from the paper)
            for _ in range(2):
                errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))

            iter_counter += 1
            if np.mod(iter_counter, FLAGS.sample_step) == 0:
                # Generate images from the fixed noise with the evaluation generator
                img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
                # Visualize generated images: tile the sample batch into an 8x8 grid
                tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
                print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))

            if np.mod(iter_counter, FLAGS.save_step) == 0:
                # Save current network parameters
                print("[*] Saving checkpoints...")
                tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
                tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
                print("[*] Saving checkpoints SUCCESS!")

    sess.close()

if __name__ == '__main__':
    tf.app.run()