
Commit e60551b

Fix the counter when the worker retries gradient aggregation in ElasticDL (#2497)

* Fix the counter when retrying
* Fix by comments
* Polish
1 parent: 0c00b66
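
For context, the control flow this commit fixes can be reduced to a short, self-contained sketch. This assumes TensorFlow 2.x; fake_allreduce is a hypothetical stand-in for the optimizer's _allreduce_grads_helper, and the constants are illustrative:

import tensorflow as tf

# Hypothetical stand-in for self._allreduce_grads_helper; a real
# implementation would average gradients across workers.
def fake_allreduce(grads):
    return [g / 2.0 for g in grads]

counter = tf.constant(5)                   # a retry already pushed it past 4
backward_passes_per_step = tf.constant(4)
grads = [tf.ones([2, 2])]

# Fixed predicate: keep the locally aggregated gradients only while the
# counter is strictly below the threshold; allreduce once it reaches or
# exceeds it.
out = tf.cond(
    tf.math.less(counter, backward_passes_per_step),
    lambda: grads,                         # still accumulating locally
    lambda: fake_allreduce(grads),         # counter >= threshold
)

With the counter at 5, this takes the allreduce branch; the old tf.equal predicate would have returned False and skipped the allreduce indefinitely.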

File tree

1 file changed (+7, −5)


elasticai_api/tensorflow/optimizer.py

Lines changed: 7 additions & 5 deletions
@@ -330,15 +330,17 @@ def compute_gradients(self, grads):
         grads = get_not_none_from_list(grads)
         assert len(grads) == len(self.locally_aggregated_grads)
 
-        # Allreduce locally aggregated gradients when the counter is
-        # equivalent to `backward_passes_per_step`. This the condition is
-        # true, it also resets the counter back to 0.
+        # Allreduce locally aggregated gradients when the counter equals
+        # or exceeds backward_passes_per_step. The counter may exceed
+        # backward_passes_per_step because of retries in the fault-tolerant
+        # allreduce. When the condition is true, it also resets the counter
+        # back to 0.
         allreduced_grads = tf.cond(
-            tf.equal(
+            tf.math.less(
                 self.counter, self.mutable_local_backward_passes_per_step
             ),
-            lambda: self._allreduce_grads_helper(grads),
             lambda: grads,
+            lambda: self._allreduce_grads_helper(grads),
         )
 
         # Handle case where there is only one variable.
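
To see why the equality check breaks under retries, here is a small plain-Python simulation; the driver loop, simulate, and retry_at are illustrative names, not code from the repository:

BACKWARD_PASSES_PER_STEP = 4

def simulate(should_allreduce, steps=8, retry_at=4):
    """Run `steps` backward passes, injecting one retry (an extra
    counter increment) at step `retry_at`."""
    counter = 0
    log = []
    for step in range(1, steps + 1):
        counter += 1
        if step == retry_at:
            counter += 1  # the retried backward pass bumps the counter again
        if should_allreduce(counter):
            log.append((step, counter, "allreduce"))
            counter = 0   # the allreduce branch resets the counter
        else:
            log.append((step, counter, "accumulate"))
    return log

old = simulate(lambda c: c == BACKWARD_PASSES_PER_STEP)  # pre-fix predicate
new = simulate(lambda c: c >= BACKWARD_PASSES_PER_STEP)  # post-fix predicate
print("old:", old)  # the counter jumps from 3 to 5, so equality never fires
print("new:", new)  # allreduces at step 4 and resets, as intended

The threshold check is idempotent under the extra increments a retry can cause, while the exact-equality check deadlocks once the counter skips past backward_passes_per_step.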
