# Module-level handle to the parsed command-line flag values
# (tf.app.flags / absl-style); read throughout main().
FLAGS = flags.FLAGS
4949def main (_ ):
50+ assert np .sqrt (FLAGS .sample_size ) % 1 == 0. , 'Flag `sample_size` needs to be a perfect square'
51+ num_tiles = int (np .sqrt (FLAGS .sample_size ))
52+
5053 # Print flags
5154 for flag , _ in FLAGS .__flags .items ():
5255 print ('"{}": {}' .format (flag , getattr (FLAGS , flag )))
@@ -62,8 +65,8 @@ def main(_):
6265 with tf .device ("/gpu:0" ):
6366
6467 """ Define Models """
65- z = tf .placeholder (tf .float32 , [FLAGS . batch_size , z_dim ], name = 'z_noise' )
66- real_images = tf .placeholder (tf .float32 , [FLAGS . batch_size , FLAGS .output_size , FLAGS .output_size , FLAGS .c_dim ], name = 'real_images' )
68+ z = tf .placeholder (tf .float32 , [None , z_dim ], name = 'z_noise' )
69+ real_images = tf .placeholder (tf .float32 , [None , FLAGS .output_size , FLAGS .output_size , FLAGS .c_dim ], name = 'real_images' )
6770
6871 # Input noise into generator for training
6972 net_g = generator (z , is_train = True , reuse = False )
@@ -77,12 +80,11 @@ def main(_):
7780 net_g2 = generator (z , is_train = False , reuse = True )
7881
7982 """ Define Training Operations """
80- # cost for updating discriminator and generator
8183 # discriminator: real images are labelled as 1
8284 d_loss_real = tl .cost .sigmoid_cross_entropy (d2_logits , tf .ones_like (d2_logits ), name = 'dreal' )
83-
8485 # discriminator: images from generator (fake) are labelled as 0
8586 d_loss_fake = tl .cost .sigmoid_cross_entropy (d_logits , tf .zeros_like (d_logits ), name = 'dfake' )
87+ # cost for updating discriminator
8688 d_loss = d_loss_real + d_loss_fake
8789
8890 # generator: try to make the the fake images look real (1)
@@ -112,17 +114,16 @@ def main(_):
112114
113115 data_files = np .array (glob (os .path .join ("./data" , FLAGS .dataset , "*.jpg" )))
114116 num_files = len (data_files )
115- shuffle = True
116117
117118 # Mini-batch generator
def iterate_minibatches(batch_size, shuffle=True):
    """Yield successive batches of real training images as float32 arrays.

    Args:
        batch_size: number of images per yielded batch; a trailing partial
            batch (fewer than ``batch_size`` files) is dropped.
        shuffle: when True, visit files in a fresh random order each call.

    Yields:
        np.ndarray of shape (batch_size, output_size, output_size, c_dim),
        dtype float32, produced by ``get_image`` preprocessing.
    """
    if shuffle:
        order = np.random.permutation(num_files)
    for offset in range(0, num_files - batch_size + 1, batch_size):
        # Shuffled: fancy-index with a permutation slice; sequential: a plain slice.
        batch_ref = order[offset:offset + batch_size] if shuffle else slice(offset, offset + batch_size)
        # More image augmentation functions at
        # [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html]
        yield np.array(
            [get_image(path, FLAGS.image_size, is_crop=FLAGS.is_crop,
                       resize_w=FLAGS.output_size, is_grayscale=0)
             for path in data_files[batch_ref]]
        ).astype(np.float32)
@@ -136,13 +137,13 @@ def iterate_minibatches():
136137 iter_counter = 0
137138 for epoch in range (FLAGS .epoch ):
138139
139- sample_images = next (iterate_minibatches ())
140+ sample_images = next (iterate_minibatches (FLAGS . sample_size ))
140141 print ("[*] Sample images updated!" )
141142
142143 steps = 0
143- for batch_images in iterate_minibatches ():
144+ for batch_images in iterate_minibatches (FLAGS . batch_size ):
144145
145- batch_z = np .random .normal (loc = 0.0 , scale = 1.0 , size = (FLAGS .sample_size , z_dim )).astype (np .float32 )
146+ batch_z = np .random .normal (loc = 0.0 , scale = 1.0 , size = (FLAGS .batch_size , z_dim )).astype (np .float32 )
146147 start_time = time .time ()
147148
148149 # Updates the Discriminator(D)
@@ -162,7 +163,7 @@ def iterate_minibatches():
162163 # Generate images
163164 img , errD , errG = sess .run ([net_g2 .outputs , d_loss , g_loss ], feed_dict = {z : sample_seed , real_images : sample_images })
164165 # Visualize generated images
165- tl .visualize .save_images (img , [8 , 8 ], './{}/train_{:02d}_{:04d}.png' .format (FLAGS .sample_dir , epoch , steps ))
166+ tl .visualize .save_images (img , [num_tiles , num_tiles ], './{}/train_{:02d}_{:04d}.png' .format (FLAGS .sample_dir , epoch , steps ))
166167 print ("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD , errG ))
167168
168169 if np .mod (iter_counter , FLAGS .save_step ) == 0 :
@@ -171,10 +172,13 @@ def iterate_minibatches():
171172 tl .files .save_npz (net_g .all_params , name = net_g_name , sess = sess )
172173 tl .files .save_npz (net_d .all_params , name = net_d_name , sess = sess )
173174 print ("[*] Saving checkpoints SUCCESS!" )
174-
175+
175176 steps += 1
176-
177+
177178 sess .close ()
178179
if __name__ == '__main__':
    # Hand control to tf.app.run(), which parses flags and invokes main().
    # A Ctrl-C during training exits cleanly instead of dumping a traceback.
    try:
        tf.app.run()
    except KeyboardInterrupt:
        print('EXIT')
0 commit comments