Commit 245c5c3

crccwaman2930 authored and committed
Rename all_reduce_sum_gradients to experimental_aggregate_gradients
For some strategies we don't do an all-reduce, so all_reduce_sum_gradients can be misleading. The parameter is also marked experimental because of issues with CentralStorageStrategy.

PiperOrigin-RevId: 302734837
1 parent d05c430 commit 245c5c3
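
For reference, here is a minimal sketch of the renamed keyword on a stock Keras optimizer. It assumes TF 2.2 or later, where apply_gradients accepts experimental_aggregate_gradients; the toy model, data, and loss are placeholders and are not part of this commit.

# A minimal sketch (not from this commit): the renamed keyword in TF >= 2.2.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))
grads = tape.gradient(loss, model.trainable_variables)

# Before this change: apply_gradients(..., all_reduce_sum_gradients=True).
# After this change the keyword is experimental_aggregate_gradients.
optimizer.apply_gradients(
    zip(grads, model.trainable_variables),
    experimental_aggregate_gradients=True)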

File tree

2 files changed: +11 −10 lines

official/nlp/optimization.py

Lines changed: 8 additions & 8 deletions

@@ -140,19 +140,19 @@ def _decay_weights_op(self, var, learning_rate, apply_state):
   def apply_gradients(self,
                       grads_and_vars,
                       name=None,
-                      all_reduce_sum_gradients=True):
+                      experimental_aggregate_gradients=True):
     grads, tvars = list(zip(*grads_and_vars))
-    if all_reduce_sum_gradients:
-      # when all_reduce_sum_gradients = False, apply_gradients() no longer
-      # implicitly allreduce gradients, users manually allreduce gradient and
-      # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
-      # will be moved to before the explicit allreduce to keep the math
-      # the same as TF 1 and pre TF 2.2 implementation.
+    if experimental_aggregate_gradients:
+      # when experimental_aggregate_gradients = False, apply_gradients() no
+      # longer implicitly allreduce gradients, users manually allreduce gradient
+      # and passed the allreduced grads_and_vars. For now, the
+      # clip_by_global_norm will be moved to before the explicit allreduce to
+      # keep the math the same as TF 1 and pre TF 2.2 implementation.
       (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
     return super(AdamWeightDecay, self).apply_gradients(
         zip(grads, tvars),
         name=name,
-        all_reduce_sum_gradients=all_reduce_sum_gradients)
+        experimental_aggregate_gradients=experimental_aggregate_gradients)

   def _get_lr(self, var_device, var_dtype, apply_state):
     """Retrieves the learning rate with the given state."""

official/staging/training/grad_utils.py

Lines changed: 3 additions & 2 deletions

@@ -54,7 +54,7 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   This utils function is used when users intent to explicitly allreduce
   gradients and customize gradients operations before and after allreduce.
   The allreduced gradients are then passed to optimizer.apply_gradients(
-  all_reduce_sum_gradients=False).
+  experimental_aggregate_gradients=False).

   Arguments:
     grads_and_vars: gradients and variables pairs.
@@ -139,4 +139,5 @@ def minimize_using_explicit_allreduce(tape,
   grads_and_vars = zip(allreduced_grads, filtered_training_vars)
   if post_allreduce_callbacks:
     grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
-  optimizer.apply_gradients(grads_and_vars, all_reduce_sum_gradients=False)
+  optimizer.apply_gradients(
+      grads_and_vars, experimental_aggregate_gradients=False)
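
A hedged sketch of the calling pattern this utility supports: aggregate gradients manually inside a replica context, then pass experimental_aggregate_gradients=False so apply_gradients does not aggregate them again. This is an illustration under MirroredStrategy with placeholder model and loss, not the actual minimize_using_explicit_allreduce implementation, and assumes a TF version with Strategy.run.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(x, y):
  def step_fn(x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # Explicit allreduce: sum the gradients across replicas ourselves.
    grads = tf.distribute.get_replica_context().all_reduce(
        tf.distribute.ReduceOp.SUM, grads)
    # Gradients are already aggregated, so skip the implicit aggregation.
    optimizer.apply_gradients(
        zip(grads, model.trainable_variables),
        experimental_aggregate_gradients=False)
    return loss
  return strategy.run(step_fn, args=(x, y))

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
train_step(x, y)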
