11import abc
22from typing import Dict , Any , Union , List , Iterable
3-
43import os
4+ import time
55import datetime
66
77import numpy as np
@@ -127,12 +127,19 @@ def should_stop(step):
127127 return False
128128
129129 while True :
130+ t0 = time .time ()
130131 train_step , global_loss , _ = self .session .run (
131132 (self .model .global_step , loss , train_op ),
132133 feed_dict = feed_dict
133134 )
135+ t1 = time .time ()
134136
135- tf .logging .info ("Step: %d\t loss: %f" , train_step , global_loss )
137+ tf .logging .info (
138+ "Step: \t %d\t loss: %f\t in %s sec" ,
139+ train_step ,
140+ global_loss ,
141+ str (np .round (t1 - t0 , 3 ))
142+ )
136143
137144 # update last_loss every N+1st step:
138145 if train_step % len (loss_hist ) == 1 :
@@ -205,12 +212,19 @@ def train(self, *args,
205212 if convergence_criteria == "step" :
206213 train_step = self .session .run (self .model .global_step , feed_dict = feed_dict )
207214 while train_step < stopping_criteria :
215+ t0 = time .time ()
208216 train_step , global_loss , _ = self .session .run (
209217 (self .model .global_step , loss , train_op ),
210218 feed_dict = feed_dict
211219 )
220+ t1 = time .time ()
212221
213- tf .logging .info ("Step: %d\t loss: %f" , train_step , global_loss )
222+ tf .logging .info (
223+ "Step: %d\t loss: %s" ,
224+ train_step ,
225+ global_loss ,
226+ str (np .round (t1 - t0 , 3 ))
227+ )
214228 elif convergence_criteria in ["all_converged_ll" , "all_converged_theta" ]:
215229 # Evaluate initial value of convergence metric:
216230 if convergence_criteria == "all_converged_theta" :
@@ -222,6 +236,7 @@ def train(self, *args,
222236
223237 while np .any (self .model .model_vars .converged == False ):
224238 # Update convergence metric reference:
239+ t0 = time .time ()
225240 metric_prev = metric_current
226241 train_step , global_loss , _ = self .session .run (
227242 (self .model .global_step , loss , train_op ),
@@ -244,12 +259,14 @@ def train(self, *args,
244259 self .model .model_vars .converged ,
245260 metric_delta < stopping_criteria
246261 )
262+ t1 = time .time ()
247263
248264 tf .logging .info (
249- "Step: \t %d\t loss: %f\t models converged %i " ,
265+ "Step: \t %d\t loss: \t %f\t models converged \t %i \t in %s sec " ,
250266 train_step ,
251267 global_loss ,
252- np .sum (self .model .model_vars .converged ).astype ("int32" )
268+ np .sum (self .model .model_vars .converged ).astype ("int32" ),
269+ str (np .round (t1 - t0 , 3 ))
253270 )
254271 else :
255272 self ._train_to_convergence (
0 commit comments