Skip to content

Commit 6e2a1d5

Browse files
Enable async checkpoint by default in TensorFlow Model Garden.
PiperOrigin-RevId: 532860554
1 parent 8f17df9 commit 6e2a1d5

File tree

3 files changed

+33
-9
lines changed

3 files changed

+33
-9
lines changed

official/core/train_lib.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
controller_cls=orbit.Controller,
7272
summary_manager: Optional[orbit.utils.SummaryManager] = None,
7373
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
74+
enable_async_checkpointing: bool = False,
7475
):
7576
"""Constructor.
7677
@@ -94,6 +95,8 @@ def __init__(
9495
summary manager.
9596
eval_summary_manager: Instance of the eval summary manager to override
9697
default eval summary manager.
98+
enable_async_checkpointing: Optional boolean indicating whether to enable
99+
async checkpoint saving.
97100
"""
98101
self.strategy = distribution_strategy or tf.distribute.get_strategy()
99102
self._params = params
@@ -115,7 +118,8 @@ def __init__(
115118
save_summary=save_summary,
116119
train_actions=train_actions,
117120
eval_actions=eval_actions,
118-
controller_cls=controller_cls)
121+
controller_cls=controller_cls,
122+
enable_async_checkpointing=enable_async_checkpointing)
119123

120124
@property
121125
def params(self) -> config_definitions.ExperimentConfig:
@@ -188,13 +192,16 @@ def _maybe_build_checkpoint_manager(
188192
checkpoint_manager = None
189193
return checkpoint_manager
190194

191-
def _build_controller(self,
192-
trainer,
193-
evaluator,
194-
save_summary: bool = True,
195-
train_actions: Optional[List[orbit.Action]] = None,
196-
eval_actions: Optional[List[orbit.Action]] = None,
197-
controller_cls=orbit.Controller) -> orbit.Controller:
195+
def _build_controller(
196+
self,
197+
trainer,
198+
evaluator,
199+
save_summary: bool = True,
200+
train_actions: Optional[List[orbit.Action]] = None,
201+
eval_actions: Optional[List[orbit.Action]] = None,
202+
controller_cls=orbit.Controller,
203+
enable_async_checkpointing: bool = False,
204+
) -> orbit.Controller:
198205
"""Builds an Orbit controller."""
199206
train_actions = [] if not train_actions else train_actions
200207
if trainer:
@@ -223,6 +230,7 @@ def _build_controller(self,
223230
global_step=self.trainer.global_step,
224231
steps_per_loop=self.params.trainer.steps_per_loop,
225232
checkpoint_manager=self.checkpoint_manager,
233+
enable_async_checkpointing=enable_async_checkpointing,
226234
summary_dir=os.path.join(self.model_dir, 'train')
227235
if (save_summary)
228236
else None,
@@ -309,6 +317,7 @@ def run_experiment(
309317
controller_cls=orbit.Controller,
310318
summary_manager: Optional[orbit.utils.SummaryManager] = None,
311319
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
320+
enable_async_checkpointing: bool = False,
312321
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
313322
"""Runs train/eval configured by the experiment params.
314323
@@ -332,6 +341,8 @@ def run_experiment(
332341
manager.
333342
eval_summary_manager: Instance of the eval summary manager to override
334343
default eval summary manager.
344+
enable_async_checkpointing: Optional boolean indicating whether to enable
345+
async checkpoint saving.
335346
336347
Returns:
337348
A 2-tuple of (model, eval_logs).
@@ -353,5 +364,6 @@ def run_experiment(
353364
controller_cls=controller_cls,
354365
summary_manager=summary_manager,
355366
eval_summary_manager=eval_summary_manager,
367+
enable_async_checkpointing=enable_async_checkpointing,
356368
)
357369
return runner.run()

official/nlp/train.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838
default=None,
3939
help='The number of total training steps for the pretraining job.')
4040

41+
flags.DEFINE_bool(
42+
'enable_async_checkpointing',
43+
default=True,
44+
help='A boolean indicating whether to enable async checkpoint saving')
45+
4146

4247
def _run_experiment_with_preemption_recovery(params, model_dir):
4348
"""Runs experiment and tries to reconnect when encountering a preemption."""
@@ -62,7 +67,8 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
6267
task=task,
6368
mode=FLAGS.mode,
6469
params=params,
65-
model_dir=model_dir)
70+
model_dir=model_dir,
71+
enable_async_checkpointing=FLAGS.enable_async_checkpointing)
6672

6773
keep_training = False
6874
except tf.errors.OpError as e:

official/vision/train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@
3232

3333
FLAGS = flags.FLAGS
3434

35+
flags.DEFINE_bool(
36+
'enable_async_checkpointing',
37+
default=True,
38+
help='A boolean indicating whether to enable async checkpoint saving')
39+
3540

3641
def _run_experiment_with_preemption_recovery(params, model_dir):
3742
"""Runs experiment and tries to reconnect when encountering a preemption."""
@@ -60,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
6065
eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
6166
params=params, model_dir=model_dir
6267
),
68+
enable_async_checkpointing=FLAGS.enable_async_checkpointing,
6369
)
6470

6571
keep_training = False

0 commit comments

Comments
 (0)