
Commit a7894f9

Internal change
PiperOrigin-RevId: 424391275
1 parent 885fda0 commit a7894f9

3 files changed: 29 additions, 4 deletions

official/modeling/optimization/configs/optimization_config.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ class OptimizerConfig(oneof.OneOfConfig):
   """
   type: Optional[str] = None
   sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
+  sgd_experimental: opt_cfg.SGDExperimentalConfig = (
+      opt_cfg.SGDExperimentalConfig())
   adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
   adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
   lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
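This registers `sgd_experimental` as a selectable branch of the one-of optimizer config. A minimal sketch of selecting it (field values are illustrative; `get()` returning the active sub-config is assumed from the Model Garden `OneOfConfig` API, which the factory itself relies on):

from official.modeling.optimization.configs import optimization_config
from official.modeling.optimization.configs import optimizer_config

# Pick the experimental branch via the one-of `type` key added here.
opt = optimization_config.OptimizerConfig(
    type='sgd_experimental',
    sgd_experimental=optimizer_config.SGDExperimentalConfig(
        momentum=0.9, nesterov=True))
active = opt.get()  # assumed: returns the SGDExperimentalConfig instance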

official/modeling/optimization/configs/optimizer_config.py

Lines changed: 20 additions & 0 deletions
@@ -54,6 +54,26 @@ class SGDConfig(BaseOptimizerConfig):
   momentum: float = 0.0


+# TODO(b/216129465): Merge this config with SGDConfig after the experimental
+# optimizer graduates.
+@dataclasses.dataclass
+class SGDExperimentalConfig(BaseOptimizerConfig):
+  """Configuration for SGD optimizer.
+
+  The attributes of this class match the arguments of
+  `tf.keras.optimizers.experimental.SGD`.
+
+  Attributes:
+    name: name of the optimizer.
+    nesterov: whether to apply Nesterov momentum.
+    momentum: momentum factor for the SGD optimizer.
+  """
+  name: str = "SGD"
+  nesterov: bool = False
+  momentum: float = 0.0
+  jit_compile: bool = False
+
+
 @dataclasses.dataclass
 class RMSPropConfig(BaseOptimizerConfig):
   """Configuration for RMSProp optimizer.

official/modeling/optimization/optimizer_factory.py

Lines changed: 7 additions & 4 deletions
@@ -18,7 +18,6 @@
 import gin
 import tensorflow as tf
 import tensorflow_addons.optimizers as tfa_optimizers
-
 from official.modeling.optimization import slide_optimizer
 from official.modeling.optimization import adafactor_optimizer
 from official.modeling.optimization import ema_optimizer
@@ -29,6 +28,7 @@

 OPTIMIZERS_CLS = {
     'sgd': tf.keras.optimizers.SGD,
+    'sgd_experimental': tf.keras.optimizers.experimental.SGD,
     'adam': tf.keras.optimizers.Adam,
     'adamw': nlp_optimization.AdamWeightDecay,
     'lamb': tfa_optimizers.LAMB,
@@ -178,7 +178,8 @@ def build_optimizer(
         takes an optimizer and returns an optimizer.

     Returns:
-      tf.keras.optimizers.Optimizer instance.
+      `tf.keras.optimizers.Optimizer` or
+      `tf.keras.optimizers.experimental.Optimizer` instance.
     """

     optimizer_dict = self._optimizer_config.as_dict()
@@ -201,8 +202,10 @@ def build_optimizer(
           optimizer, **self._ema_config.as_dict())
     if postprocessor:
       optimizer = postprocessor(optimizer)
-    assert isinstance(optimizer, tf.keras.optimizers.Optimizer), (
-        'OptimizerFactory.build_optimizer returning a non-optimizer object: '
+    assert isinstance(
+        optimizer, (tf.keras.optimizers.Optimizer,
+                    tf.keras.optimizers.experimental.Optimizer)
+    ), ('OptimizerFactory.build_optimizer returning a non-optimizer object: '
         '{}'.format(optimizer))

     return optimizer
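Taken together, the new key plugs into the existing factory flow. A usage sketch modeled on this package's factory tests; the dict layout and the 'constant' learning-rate branch are assumptions about the config schema, not shown in this diff:

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

params = optimization_config.OptimizationConfig({
    'optimizer': {
        'type': 'sgd_experimental',
        'sgd_experimental': {'momentum': 0.9}
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {'learning_rate': 0.1}
    }
})
factory = optimizer_factory.OptimizerFactory(params)
lr = factory.build_learning_rate()
optimizer = factory.build_optimizer(lr)  # satisfies the widened isinstance check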
