@@ -71,6 +71,7 @@ def __init__(
71
71
controller_cls = orbit .Controller ,
72
72
summary_manager : Optional [orbit .utils .SummaryManager ] = None ,
73
73
eval_summary_manager : Optional [orbit .utils .SummaryManager ] = None ,
74
+ enable_async_checkpointing : bool = False ,
74
75
):
75
76
"""Constructor.
76
77
@@ -94,6 +95,8 @@ def __init__(
94
95
summary manager.
95
96
eval_summary_manager: Instance of the eval summary manager to override
96
97
default eval summary manager.
98
+ enable_async_checkpointing: Optional boolean indicating whether to enable
99
+ async checkpoint saving.
97
100
"""
98
101
self .strategy = distribution_strategy or tf .distribute .get_strategy ()
99
102
self ._params = params
@@ -115,7 +118,8 @@ def __init__(
115
118
save_summary = save_summary ,
116
119
train_actions = train_actions ,
117
120
eval_actions = eval_actions ,
118
- controller_cls = controller_cls )
121
+ controller_cls = controller_cls ,
122
+ enable_async_checkpointing = enable_async_checkpointing )
119
123
120
124
@property
121
125
def params (self ) -> config_definitions .ExperimentConfig :
@@ -188,13 +192,16 @@ def _maybe_build_checkpoint_manager(
188
192
checkpoint_manager = None
189
193
return checkpoint_manager
190
194
191
- def _build_controller (self ,
192
- trainer ,
193
- evaluator ,
194
- save_summary : bool = True ,
195
- train_actions : Optional [List [orbit .Action ]] = None ,
196
- eval_actions : Optional [List [orbit .Action ]] = None ,
197
- controller_cls = orbit .Controller ) -> orbit .Controller :
195
+ def _build_controller (
196
+ self ,
197
+ trainer ,
198
+ evaluator ,
199
+ save_summary : bool = True ,
200
+ train_actions : Optional [List [orbit .Action ]] = None ,
201
+ eval_actions : Optional [List [orbit .Action ]] = None ,
202
+ controller_cls = orbit .Controller ,
203
+ enable_async_checkpointing : bool = False ,
204
+ ) -> orbit .Controller :
198
205
"""Builds an Orbit controller."""
199
206
train_actions = [] if not train_actions else train_actions
200
207
if trainer :
@@ -223,6 +230,7 @@ def _build_controller(self,
223
230
global_step = self .trainer .global_step ,
224
231
steps_per_loop = self .params .trainer .steps_per_loop ,
225
232
checkpoint_manager = self .checkpoint_manager ,
233
+ enable_async_checkpointing = enable_async_checkpointing ,
226
234
summary_dir = os .path .join (self .model_dir , 'train' )
227
235
if (save_summary )
228
236
else None ,
@@ -309,6 +317,7 @@ def run_experiment(
309
317
controller_cls = orbit .Controller ,
310
318
summary_manager : Optional [orbit .utils .SummaryManager ] = None ,
311
319
eval_summary_manager : Optional [orbit .utils .SummaryManager ] = None ,
320
+ enable_async_checkpointing : bool = False ,
312
321
) -> Tuple [tf .keras .Model , Mapping [str , Any ]]:
313
322
"""Runs train/eval configured by the experiment params.
314
323
@@ -332,6 +341,8 @@ def run_experiment(
332
341
manager.
333
342
eval_summary_manager: Instance of the eval summary manager to override
334
343
default eval summary manager.
344
+ enable_async_checkpointing: Optional boolean indicating whether to enable
345
+ async checkpoint saving.
335
346
336
347
Returns:
337
348
A 2-tuple of (model, eval_logs).
@@ -353,5 +364,6 @@ def run_experiment(
353
364
controller_cls = controller_cls ,
354
365
summary_manager = summary_manager ,
355
366
eval_summary_manager = eval_summary_manager ,
367
+ enable_async_checkpointing = enable_async_checkpointing ,
356
368
)
357
369
return runner .run ()
0 commit comments