Skip to content

Commit 00756f3

Browse files
fixed the global loss is evaluated after train_op
this may cause more likelihood evaluations
1 parent e3dfa36 commit 00756f3

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

batchglm/train/tf/base/estimator.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,12 @@ def should_stop(step):
137137

138138
while True:
139139
t0 = time.time()
140-
train_step, _, global_loss = self.session.run(
141-
(self.model.global_step, train_op, loss),
140+
train_step, _ = self.session.run(
141+
(self.model.global_step, train_op),
142+
feed_dict=feed_dict
143+
)
144+
global_loss = self.session.run(
145+
(loss),
142146
feed_dict=feed_dict
143147
)
144148
t1 = time.time()
@@ -232,8 +236,12 @@ def train(self, *args,
232236

233237
while train_step < stopping_criteria:
234238
t0 = time.time()
235-
train_step, _, global_loss = self.session.run(
236-
(self.model.global_step, train_op, loss),
239+
train_step, _ = self.session.run(
240+
(self.model.global_step, train_op),
241+
feed_dict=feed_dict
242+
)
243+
global_loss = self.session.run(
244+
(loss),
237245
feed_dict=feed_dict
238246
)
239247
t1 = time.time()
@@ -266,8 +274,12 @@ def train(self, *args,
266274
# Update convergence metric reference:
267275
t0 = time.time()
268276
metric_prev = metric_current
269-
train_step, _, global_loss = self.session.run(
270-
(self.model.global_step, train_op, loss),
277+
train_step, _ = self.session.run(
278+
(self.model.global_step, train_op),
279+
feed_dict=feed_dict
280+
)
281+
global_loss = self.session.run(
282+
(loss),
271283
feed_dict=feed_dict
272284
)
273285
# Evaluate convergence metric:

0 commit comments

Comments
 (0)