Skip to content

Commit 3d10262

Browse files
Raise an explicit error if decay is set and new Keras optimizer is used.
PiperOrigin-RevId: 481980126
1 parent 6e2129f commit 3d10262

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

official/modeling/optimization/optimizer_factory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ def build_optimizer(
236236
if use_legacy_optimizer:
237237
optimizer = LEGACY_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
238238
else:
239+
if 'decay' in optimizer_dict:
240+
raise ValueError(
241+
'`decay` is deprecated in new Keras optimizer, please reflect the '
242+
'decay logic in `lr` or set `use_legacy_optimizer=True` to use the '
243+
'legacy optimizer.')
239244
optimizer = NEW_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
240245

241246
if self._use_ema:

official/modeling/optimization/optimizer_factory_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def test_new_optimizers(self, optimizer_type):
6868
expected_optimizer_config['learning_rate'] = 0.1
6969

7070
opt_config = optimization_config.OptimizationConfig(params)
71+
if optimizer_type == 'sgd':
72+
# Delete unsupported arg `decay` from SGDConfig.
73+
delattr(opt_config.optimizer.sgd, 'decay')
7174
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
7275
lr = opt_factory.build_learning_rate()
7376
optimizer = opt_factory.build_optimizer(

0 commit comments

Comments (0)