# limitations under the License.

"""Optimizer factory class."""
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List, Tuple

import gin
import tensorflow as tf
@@ -139,6 +139,9 @@ def build_learning_rate(self):
   def build_optimizer(
       self,
       lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
+      gradient_transformers: Optional[List[Callable[
+          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]
+      ]]] = None,
       postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                        tf.keras.optimizers.Optimizer]] = None):
     """Build optimizer.
@@ -150,6 +153,11 @@ def build_optimizer(
     Args:
       lr: A floating point value, or a
         tf.keras.optimizers.schedules.LearningRateSchedule instance.
+      gradient_transformers: Optional list of functions used to transform
+        gradients before updates are applied to Variables. The functions are
+        applied after gradient_aggregator. Each function should accept and
+        return a list of (gradient, variable) tuples. clipvalue, clipnorm, and
+        global_clipnorm should not be set when gradient_transformers is
+        passed.
       postprocessor: An optional function for postprocessing the optimizer. It
         takes an optimizer and returns an optimizer.

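For illustration, a gradient transformer is just a callable matching the signature documented above: it receives the aggregated list of (gradient, variable) pairs and returns a list of the same form. A minimal sketch (a hypothetical example, not part of this change) that clips each dense gradient elementwise to [-1.0, 1.0]:

import tensorflow as tf

def clip_by_value_transformer(grads_and_vars):
  # Hypothetical example transformer: takes the aggregated
  # (gradient, variable) pairs and returns a list of the same shape,
  # clipping each gradient elementwise before the update is applied.
  return [(tf.clip_by_value(g, -1.0, 1.0), v) for g, v in grads_and_vars]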
@@ -158,13 +166,17 @@ def build_optimizer(
     """

     optimizer_dict = self._optimizer_config.as_dict()
-    ## Delete clipnorm and clipvalue if None
+    ## Delete clipnorm, clipvalue, global_clipnorm if None
     if optimizer_dict['clipnorm'] is None:
       del optimizer_dict['clipnorm']
     if optimizer_dict['clipvalue'] is None:
       del optimizer_dict['clipvalue']
+    if optimizer_dict['global_clipnorm'] is None:
+      del optimizer_dict['global_clipnorm']

     optimizer_dict['learning_rate'] = lr
+    if gradient_transformers is not None:
+      optimizer_dict['gradient_transformers'] = gradient_transformers

     optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)

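A usage sketch under stated assumptions: OptimizerFactory and optimizer_config are assumed names for the surrounding factory class and its config object (build_learning_rate and build_optimizer appear in the diff context above). The config must leave clipnorm, clipvalue, and global_clipnorm unset, since they cannot be combined with gradient_transformers:

# Usage sketch; OptimizerFactory and optimizer_config are assumed names.
opt_factory = OptimizerFactory(optimizer_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(
    lr, gradient_transformers=[clip_by_value_transformer])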