import tensorlayer as tl

from glob import glob
- from random import shuffle

- from model import generator_simplified_api, discriminator_simplified_api
from utils import get_image
+ from model import generator, discriminator

- # Defile TF Flags
+ # Define TF Flags
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
@@ -67,15 +66,15 @@ def main(_):
    real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')

    # Input noise into generator for training
-     net_g, g_logits = generator_simplified_api(z, is_train=True, reuse=False)
+     net_g = generator(z, is_train=True, reuse=False)

    # Input real and generated fake images into discriminator for training
-     net_d, d_logits = discriminator_simplified_api(net_g.outputs, is_train=True, reuse=False)
-     net_d2, d2_logits = discriminator_simplified_api(real_images, is_train=True, reuse=True)
+     net_d, d_logits = discriminator(net_g.outputs, is_train=True, reuse=False)
+     _, d2_logits = discriminator(real_images, is_train=True, reuse=True)

    # Input noise into generator for evaluation
    # set is_train to False so that BatchNormLayer behaves differently
-     net_g2, g2_logits = generator_simplified_api(z, is_train=False, reuse=True)
+     net_g2 = generator(z, is_train=False, reuse=True)

    """ Define Training Operations """
    # cost for updating discriminator and generator
@@ -111,51 +110,59 @@ def main(_):
    net_g_name = os.path.join(save_dir, 'net_g.npz')
    net_d_name = os.path.join(save_dir, 'net_d.npz')

-     data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))
+     data_files = np.array(glob(os.path.join("./data", FLAGS.dataset, "*.jpg")))
+     num_files = len(data_files)
+     shuffle = True
+
+     # Mini-batch generator
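+     # Yields batches of FLAGS.batch_size preprocessed real images; when
+     # shuffle is True the file order is re-permuted on every fresh call,
+     # and leftover files that do not fill a final batch are dropped.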
+     def iterate_minibatches():
+         if shuffle:
+             indices = np.random.permutation(num_files)
+         for start_idx in range(0, num_files - FLAGS.batch_size + 1, FLAGS.batch_size):
+             if shuffle:
+                 excerpt = indices[start_idx:start_idx + FLAGS.batch_size]
+             else:
+                 excerpt = slice(start_idx, start_idx + FLAGS.batch_size)
+             # Get real images (more image augmentation functions at http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html)
+             yield np.array([get_image(file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0)
+                             for file in data_files[excerpt]]).astype(np.float32)
+
+     batch_steps = min(num_files, FLAGS.train_size) // FLAGS.batch_size

-     sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)  # sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
+     # sample noise
+     sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)

    """ Training models """
    iter_counter = 0
    for epoch in range(FLAGS.epoch):

-         # Shuffle data
-         shuffle(data_files)
-
-         # Update sample files based on shuffled data
-         sample_files = data_files[0:FLAGS.sample_size]
-         sample = [get_image(sample_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0) for sample_file in sample_files]
-         sample_images = np.array(sample).astype(np.float32)
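+         # Take the first shuffled batch of the epoch as the sample images;
+         # note this relies on FLAGS.sample_size == FLAGS.batch_size.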
+         sample_images = next(iterate_minibatches())
        print("[*] Sample images updated!")
+
+         steps = 0
+         for batch_images in iterate_minibatches():

-         # Load image data
-         batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size
-
-         for idx in range(0, batch_idxs):
-             batch_files = data_files[idx * FLAGS.batch_size:(idx + 1) * FLAGS.batch_size]
-
-             # Get real images (more image augmentation functions at http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html)
-             batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0) for batch_file in batch_files]
-             batch_images = np.array(batch).astype(np.float32)
            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
            start_time = time.time()

            # Updates the Discriminator(D)
-             errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})
+             errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})

            # Updates the Generator(G)
            # run generator twice to make sure that d_loss does not go to zero (different from paper)
            for _ in range(2):
                errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
+
+             end_time = time.time() - start_time
            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
-                 % (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))
+                 % (epoch, FLAGS.epoch, steps, batch_steps, end_time, errD, errG))

            iter_counter += 1
            if np.mod(iter_counter, FLAGS.sample_step) == 0:
                # Generate images
-                 img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
+                 img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
                # Visualize generated images
-                 tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
+                 tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps))
                print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))

            if np.mod(iter_counter, FLAGS.save_step) == 0:
@@ -164,6 +171,8 @@ def main(_):
                tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
                tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
                print("[*] Saving checkpoints SUCCESS!")
+
+             steps += 1

    sess.close()

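The core change above replaces an in-place random.shuffle plus manual index
slicing with a Python generator that owns both the shuffling and the
get_image preprocessing. A minimal, self-contained sketch of the same
pattern, stripped of the FLAGS and TensorLayer specifics (NumPy only;
minibatches, files, batch_size, and load_fn are illustrative names, not part
of this repo):

    import numpy as np

    def minibatches(files, batch_size, load_fn, shuffle=True):
        # Permute indices once per pass and drop the last incomplete batch,
        # mirroring what iterate_minibatches() above does.
        n = len(files)
        indices = np.random.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n - batch_size + 1, batch_size):
            excerpt = indices[start:start + batch_size]
            yield np.array([load_fn(files[i]) for i in excerpt], dtype=np.float32)

    # Usage: build a fresh generator (and thus a fresh permutation) per epoch,
    # as the training loop above does with iterate_minibatches().
    # for batch in minibatches(data_files, 64, load_fn):
    #     ...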