Skip to content

Commit bcfba28

Browse files
saberkunallenwang28
authored andcommitted
Fix lr in callback
PiperOrigin-RevId: 304250237
1 parent 52c6556 commit bcfba28

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

official/vision/image_classification/callbacks.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,22 @@ def get_callbacks(model_checkpoint: bool = True,
4242
callbacks = []
4343
if model_checkpoint:
4444
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
45-
callbacks.append(tf.keras.callbacks.ModelCheckpoint(
46-
ckpt_full_path, save_weights_only=True, verbose=1))
45+
callbacks.append(
46+
tf.keras.callbacks.ModelCheckpoint(
47+
ckpt_full_path, save_weights_only=True, verbose=1))
4748
if include_tensorboard:
48-
callbacks.append(CustomTensorBoard(
49-
log_dir=model_dir,
50-
track_lr=track_lr,
51-
initial_step=initial_step,
52-
write_images=write_model_weights))
49+
callbacks.append(
50+
CustomTensorBoard(
51+
log_dir=model_dir,
52+
track_lr=track_lr,
53+
initial_step=initial_step,
54+
write_images=write_model_weights))
5355
if time_history:
54-
callbacks.append(keras_utils.TimeHistory(
55-
batch_size,
56-
log_steps,
57-
logdir=model_dir if include_tensorboard else None))
56+
callbacks.append(
57+
keras_utils.TimeHistory(
58+
batch_size,
59+
log_steps,
60+
logdir=model_dir if include_tensorboard else None))
5861
return callbacks
5962

6063

@@ -74,13 +77,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
7477
- Global learning rate
7578
7679
Attributes:
77-
log_dir: the path of the directory where to save the log files to be
78-
parsed by TensorBoard.
80+
log_dir: the path of the directory where to save the log files to be parsed
81+
by TensorBoard.
7982
track_lr: `bool`, whether or not to track the global learning rate.
8083
initial_step: the initial step, used for preemption recovery.
81-
**kwargs: Additional arguments for backwards compatibility. Possible key
82-
is `period`.
84+
**kwargs: Additional arguments for backwards compatibility. Possible key is
85+
`period`.
8386
"""
87+
8488
# TODO(b/146499062): track params, flops, log lr, l2 loss,
8589
# classification loss
8690

@@ -130,10 +134,8 @@ def _calculate_metrics(self) -> MutableMapping[str, Any]:
130134

131135
def _calculate_lr(self) -> int:
132136
"""Calculates the learning rate given the current step."""
133-
lr = self._get_base_optimizer().lr
134-
if callable(lr):
135-
lr = lr(self.step)
136-
return get_scalar_from_tensor(lr)
137+
return get_scalar_from_tensor(
138+
self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))
137139

138140
def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
139141
"""Get the base optimizer used by the current model."""

0 commit comments

Comments
 (0)