@@ -137,8 +137,12 @@ def should_stop(step):
137137
138138 while True :
139139 t0 = time .time ()
140- train_step , _ , global_loss = self .session .run (
141- (self .model .global_step , train_op , loss ),
140+ train_step , _ = self .session .run (
141+ (self .model .global_step , train_op ),
142+ feed_dict = feed_dict
143+ )
144+ global_loss = self .session .run (
145+ (loss ),
142146 feed_dict = feed_dict
143147 )
144148 t1 = time .time ()
@@ -232,8 +236,12 @@ def train(self, *args,
232236
233237 while train_step < stopping_criteria :
234238 t0 = time .time ()
235- train_step , _ , global_loss = self .session .run (
236- (self .model .global_step , train_op , loss ),
239+ train_step , _ = self .session .run (
240+ (self .model .global_step , train_op ),
241+ feed_dict = feed_dict
242+ )
243+ global_loss = self .session .run (
244+ (loss ),
237245 feed_dict = feed_dict
238246 )
239247 t1 = time .time ()
@@ -266,8 +274,12 @@ def train(self, *args,
266274 # Update convergence metric reference:
267275 t0 = time .time ()
268276 metric_prev = metric_current
269- train_step , _ , global_loss = self .session .run (
270- (self .model .global_step , train_op , loss ),
277+ train_step , _ = self .session .run (
278+ (self .model .global_step , train_op ),
279+ feed_dict = feed_dict
280+ )
281+ global_loss = self .session .run (
282+ (loss ),
271283 feed_dict = feed_dict
272284 )
273285 # Evaluate convergence metric:
0 commit comments