@@ -42,19 +42,22 @@ def get_callbacks(model_checkpoint: bool = True,
42
42
callbacks = []
43
43
if model_checkpoint :
44
44
ckpt_full_path = os .path .join (model_dir , 'model.ckpt-{epoch:04d}' )
45
- callbacks .append (tf .keras .callbacks .ModelCheckpoint (
46
- ckpt_full_path , save_weights_only = True , verbose = 1 ))
45
+ callbacks .append (
46
+ tf .keras .callbacks .ModelCheckpoint (
47
+ ckpt_full_path , save_weights_only = True , verbose = 1 ))
47
48
if include_tensorboard :
48
- callbacks .append (CustomTensorBoard (
49
- log_dir = model_dir ,
50
- track_lr = track_lr ,
51
- initial_step = initial_step ,
52
- write_images = write_model_weights ))
49
+ callbacks .append (
50
+ CustomTensorBoard (
51
+ log_dir = model_dir ,
52
+ track_lr = track_lr ,
53
+ initial_step = initial_step ,
54
+ write_images = write_model_weights ))
53
55
if time_history :
54
- callbacks .append (keras_utils .TimeHistory (
55
- batch_size ,
56
- log_steps ,
57
- logdir = model_dir if include_tensorboard else None ))
56
+ callbacks .append (
57
+ keras_utils .TimeHistory (
58
+ batch_size ,
59
+ log_steps ,
60
+ logdir = model_dir if include_tensorboard else None ))
58
61
return callbacks
59
62
60
63
@@ -74,13 +77,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
74
77
- Global learning rate
75
78
76
79
Attributes:
77
- log_dir: the path of the directory where to save the log files to be
78
- parsed by TensorBoard.
80
+ log_dir: the path of the directory where to save the log files to be parsed
81
+ by TensorBoard.
79
82
track_lr: `bool`, whether or not to track the global learning rate.
80
83
initial_step: the initial step, used for preemption recovery.
81
- **kwargs: Additional arguments for backwards compatibility. Possible key
82
- is `period`.
84
+ **kwargs: Additional arguments for backwards compatibility. Possible key is
85
+ `period`.
83
86
"""
87
+
84
88
# TODO(b/146499062): track params, flops, log lr, l2 loss,
85
89
# classification loss
86
90
@@ -130,10 +134,8 @@ def _calculate_metrics(self) -> MutableMapping[str, Any]:
130
134
131
135
def _calculate_lr(self) -> float:
  """Calculates the current learning rate given the training step.

  Delegates to the base optimizer's `_decayed_lr` helper so that both a
  fixed scalar learning rate and a `LearningRateSchedule` are handled
  uniformly (the schedule is evaluated at the optimizer's current
  iteration).

  Returns:
    The learning rate as a Python float scalar.
  """
  # NOTE(review): `_decayed_lr` is a private Keras optimizer API; it is
  # assumed to return the (possibly schedule-evaluated) learning rate as a
  # tf.float32 tensor — confirm it exists on the TF version in use.
  # The return annotation was corrected from `int` to `float`: the value is
  # explicitly requested at var_dtype=tf.float32.
  return get_scalar_from_tensor(
      self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))
137
139
138
140
def _get_base_optimizer (self ) -> tf .keras .optimizers .Optimizer :
139
141
"""Get the base optimizer used by the current model."""
0 commit comments