
Commit 5bf90fb

Internal change
PiperOrigin-RevId: 397848064
1 parent 7f69eb3 commit 5bf90fb

File tree

1 file changed: +14 -2 lines changed

official/modeling/optimization/optimizer_factory.py

Lines changed: 14 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 """Optimizer factory class."""
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List, Tuple
 
 import gin
 import tensorflow as tf
@@ -139,6 +139,9 @@ def build_learning_rate(self):
   def build_optimizer(
       self,
       lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
+      gradient_transformers: Optional[List[Callable[
+          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]
+      ]]] = None,
       postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                        tf.keras.optimizers.Optimizer]] = None):
     """Build optimizer.
@@ -150,6 +153,11 @@ def build_optimizer(
     Args:
       lr: A floating point value, or a
         tf.keras.optimizers.schedules.LearningRateSchedule instance.
+      gradient_transformers: Optional list of functions to use to transform
+        gradients before applying updates to Variables. The functions are
+        applied after gradient_aggregator. The functions should accept and
+        return a list of (gradient, variable) tuples. clipvalue, clipnorm,
+        global_clipnorm should not be set when gradient_transformers is passed.
       postprocessor: An optional function for postprocessing the optimizer. It
         takes an optimizer and returns an optimizer.
 
@@ -158,13 +166,17 @@ def build_optimizer(
     """
 
     optimizer_dict = self._optimizer_config.as_dict()
-    ## Delete clipnorm and clipvalue if None
+    ## Delete clipnorm, clipvalue, global_clipnorm if None
     if optimizer_dict['clipnorm'] is None:
       del optimizer_dict['clipnorm']
     if optimizer_dict['clipvalue'] is None:
       del optimizer_dict['clipvalue']
+    if optimizer_dict['global_clipnorm'] is None:
+      del optimizer_dict['global_clipnorm']
 
     optimizer_dict['learning_rate'] = lr
+    if gradient_transformers is not None:
+      optimizer_dict['gradient_transformers'] = gradient_transformers
 
     optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
 
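For context, here is a minimal usage sketch (not part of this commit) showing how the new gradient_transformers argument to OptimizerFactory.build_optimizer might be used. The OptimizationConfig dictionary below is an assumed, illustrative configuration; only the gradient_transformers parameter itself comes from this change.

import tensorflow as tf

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config


def clip_gradients(grads_and_vars):
  # A transformer must accept and return a list of (gradient, variable)
  # tuples; this one clips each gradient by its norm.
  return [(tf.clip_by_norm(g, 1.0), v) for g, v in grads_and_vars]


# Illustrative (assumed) config: SGD with a constant learning rate.
# clipnorm, clipvalue, and global_clipnorm are left unset because a
# gradient_transformers list is passed below.
opt_config = optimization_config.OptimizationConfig({
    'optimizer': {'type': 'sgd'},
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
})
factory = optimizer_factory.OptimizerFactory(opt_config)
lr = factory.build_learning_rate()
optimizer = factory.build_optimizer(lr, gradient_transformers=[clip_gradients])

With a transformer supplied, clipping happens inside the transformer rather than via the clipnorm/clipvalue/global_clipnorm options, matching the constraint stated in the new docstring.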