File tree Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -330,15 +330,17 @@ def compute_gradients(self, grads):
330
330
grads = get_not_none_from_list (grads )
331
331
assert len (grads ) == len (self .locally_aggregated_grads )
332
332
333
- # Allreduce locally aggregated gradients when the counter is
334
- # equivalent to `backward_passes_per_step`. This the condition is
335
- # true, it also resets the counter back to 0.
333
+ # Allreduce locally aggregated gradients when the counter equals
334
+ # or exceeds backward_passes_per_step. The counter may exceed
335
+ # backward_passes_per_step because of retries in the fault-tolerant
336
+ # allreduce. When the condition is true, it also resets the counter
337
+ # back to 0.
336
338
allreduced_grads = tf .cond (
337
- tf .equal (
339
+ tf .math . less (
338
340
self .counter , self .mutable_local_backward_passes_per_step
339
341
),
340
- lambda : self ._allreduce_grads_helper (grads ),
341
342
lambda : grads ,
343
+ lambda : self ._allreduce_grads_helper (grads ),
342
344
)
343
345
344
346
# Handle case where there is only one variable.
You can’t perform that action at this time.
0 commit comments