Skip to content

Commit e3dfa36

Browse files
loss evaluated after train_op
this doesnt make a difference yet.
1 parent 4442029 commit e3dfa36

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

batchglm/train/tf/base/estimator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ def should_stop(step):
137137

138138
while True:
139139
t0 = time.time()
140-
train_step, global_loss, _ = self.session.run(
141-
(self.model.global_step, loss, train_op),
140+
train_step, _, global_loss = self.session.run(
141+
(self.model.global_step, train_op, loss),
142142
feed_dict=feed_dict
143143
)
144144
t1 = time.time()
@@ -232,8 +232,8 @@ def train(self, *args,
232232

233233
while train_step < stopping_criteria:
234234
t0 = time.time()
235-
train_step, global_loss, _ = self.session.run(
236-
(self.model.global_step, loss, train_op),
235+
train_step, _, global_loss = self.session.run(
236+
(self.model.global_step, train_op, loss),
237237
feed_dict=feed_dict
238238
)
239239
t1 = time.time()
@@ -266,8 +266,8 @@ def train(self, *args,
266266
# Update convergence metric reference:
267267
t0 = time.time()
268268
metric_prev = metric_current
269-
train_step, global_loss, _ = self.session.run(
270-
(self.model.global_step, loss, train_op),
269+
train_step, _, global_loss = self.session.run(
270+
(self.model.global_step, train_op, loss),
271271
feed_dict=feed_dict
272272
)
273273
# Evaluate convergence metric:

0 commit comments

Comments
 (0)