1515"""Multitask training driver library."""
1616# pytype: disable=attribute-error
1717import os
18- from typing import Any , List , Mapping , Optional , Tuple , Union
18+ from typing import Any , List , Mapping , Optional , Tuple , Union , Callable
1919from absl import logging
2020import orbit
2121import tensorflow as tf , tf_keras
@@ -157,6 +157,25 @@ def timeout_fn():
157157 return model
158158
159159
160+ TrainActionsFactoryType = Callable [
161+ [
162+ configs .MultiEvalExperimentConfig ,
163+ orbit .StandardTrainer ,
164+ str ,
165+ tf .train .CheckpointManager ,
166+ ],
167+ List [orbit .Action ],
168+ ]
169+ EvalActionsFactoryType = Callable [
170+ [
171+ configs .MultiEvalExperimentConfig ,
172+ orbit .AbstractEvaluator ,
173+ str ,
174+ ],
175+ List [orbit .Action ],
176+ ]
177+
178+
160179def run_experiment_with_multitask_eval (
161180 * ,
162181 distribution_strategy : tf .distribute .Strategy ,
@@ -171,6 +190,8 @@ def run_experiment_with_multitask_eval(
171190 eval_summary_manager : Optional [orbit .utils .SummaryManagerInterface ] = None ,
172191 best_ckpt_exporter_creator : Optional [Any ] = train_utils
173192 .maybe_create_best_ckpt_exporter ,
193+ train_actions_factory : Optional [TrainActionsFactoryType ] = None ,
194+ eval_actions_factory : Optional [EvalActionsFactoryType ] = None ,
174195) -> Tuple [Any , Any ]:
175196 """Runs train/eval configured by the experiment params.
176197
@@ -193,6 +214,8 @@ def run_experiment_with_multitask_eval(
193214 will be created internally for TensorBoard summaries by default from the
194215 `eval_summary_dir`.
195216 best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
217+ train_actions_factory: Optional factory function to create train actions.
218+ eval_actions_factory: Optional factory function to create eval actions.
196219
197220 Returns:
198221 model: `tf_keras.Model` instance.
@@ -214,7 +237,6 @@ def run_experiment_with_multitask_eval(
214237
215238 # Build the model or fetch the pre-cached one (which could be either
216239 # multi-task model or single task model).
217- model = None
218240 if trainer is None :
219241 if isinstance (train_task , multitask .MultiTask ):
220242 model = train_task .build_multitask_model ()
@@ -254,6 +276,23 @@ def run_experiment_with_multitask_eval(
254276 checkpoint_interval = params .trainer .checkpoint_interval ,
255277 init_fn = trainer .initialize if trainer else None )
256278
279+ if trainer and train_actions_factory :
280+ # pytype: disable=wrong-keyword-args
281+ train_actions = train_actions_factory (
282+ params = params ,
283+ trainer = trainer ,
284+ model_dir = model_dir ,
285+ checkpoint_manager = checkpoint_manager ,
286+ )
287+ # pytype: enable=wrong-keyword-args
288+ else :
289+ train_actions = None
290+
291+ if evaluator and eval_actions_factory :
292+ eval_actions = eval_actions_factory (params , evaluator , model_dir )
293+ else :
294+ eval_actions = None
295+
257296 controller = orbit .Controller (
258297 strategy = distribution_strategy ,
259298 trainer = trainer ,
@@ -266,7 +305,10 @@ def run_experiment_with_multitask_eval(
266305 (save_summary ) else None ,
267306 eval_summary_manager = eval_summary_manager ,
268307 summary_interval = params .trainer .summary_interval if
269- (save_summary ) else None )
308+ (save_summary ) else None ,
309+ train_actions = train_actions ,
310+ eval_actions = eval_actions ,
311+ )
270312
271313 logging .info ('Starts to execute mode: %s' , mode )
272314 with distribution_strategy .scope ():
0 commit comments