Skip to content

Commit d468f8e

Browse files
No public description
PiperOrigin-RevId: 730886312
1 parent 26b227f commit d468f8e

File tree

1 file changed

+45
-3
lines changed

1 file changed

+45
-3
lines changed

official/modeling/multitask/train_lib.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Multitask training driver library."""
1616
# pytype: disable=attribute-error
1717
import os
18-
from typing import Any, List, Mapping, Optional, Tuple, Union
18+
from typing import Any, List, Mapping, Optional, Tuple, Union, Callable
1919
from absl import logging
2020
import orbit
2121
import tensorflow as tf, tf_keras
@@ -157,6 +157,25 @@ def timeout_fn():
157157
return model
158158

159159

160+
TrainActionsFactoryType = Callable[
161+
[
162+
configs.MultiEvalExperimentConfig,
163+
orbit.StandardTrainer,
164+
str,
165+
tf.train.CheckpointManager,
166+
],
167+
List[orbit.Action],
168+
]
169+
EvalActionsFactoryType = Callable[
170+
[
171+
configs.MultiEvalExperimentConfig,
172+
orbit.AbstractEvaluator,
173+
str,
174+
],
175+
List[orbit.Action],
176+
]
177+
178+
160179
def run_experiment_with_multitask_eval(
161180
*,
162181
distribution_strategy: tf.distribute.Strategy,
@@ -171,6 +190,8 @@ def run_experiment_with_multitask_eval(
171190
eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
172191
best_ckpt_exporter_creator: Optional[Any] = train_utils
173192
.maybe_create_best_ckpt_exporter,
193+
train_actions_factory: Optional[TrainActionsFactoryType] = None,
194+
eval_actions_factory: Optional[EvalActionsFactoryType] = None,
174195
) -> Tuple[Any, Any]:
175196
"""Runs train/eval configured by the experiment params.
176197
@@ -193,6 +214,8 @@ def run_experiment_with_multitask_eval(
193214
will be created internally for TensorBoard summaries by default from the
194215
`eval_summary_dir`.
195216
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
217+
train_actions_factory: Optional factory function to create train actions.
218+
eval_actions_factory: Optional factory function to create eval actions.
196219
197220
Returns:
198221
model: `tf_keras.Model` instance.
@@ -214,7 +237,6 @@ def run_experiment_with_multitask_eval(
214237

215238
# Build the model or fetch the pre-cached one (which could be either
216239
# multi-task model or single task model).
217-
model = None
218240
if trainer is None:
219241
if isinstance(train_task, multitask.MultiTask):
220242
model = train_task.build_multitask_model()
@@ -254,6 +276,23 @@ def run_experiment_with_multitask_eval(
254276
checkpoint_interval=params.trainer.checkpoint_interval,
255277
init_fn=trainer.initialize if trainer else None)
256278

279+
if trainer and train_actions_factory:
280+
# pytype: disable=wrong-keyword-args
281+
train_actions = train_actions_factory(
282+
params=params,
283+
trainer=trainer,
284+
model_dir=model_dir,
285+
checkpoint_manager=checkpoint_manager,
286+
)
287+
# pytype: enable=wrong-keyword-args
288+
else:
289+
train_actions = None
290+
291+
if evaluator and eval_actions_factory:
292+
eval_actions = eval_actions_factory(params, evaluator, model_dir)
293+
else:
294+
eval_actions = None
295+
257296
controller = orbit.Controller(
258297
strategy=distribution_strategy,
259298
trainer=trainer,
@@ -266,7 +305,10 @@ def run_experiment_with_multitask_eval(
266305
(save_summary) else None,
267306
eval_summary_manager=eval_summary_manager,
268307
summary_interval=params.trainer.summary_interval if
269-
(save_summary) else None)
308+
(save_summary) else None,
309+
train_actions=train_actions,
310+
eval_actions=eval_actions,
311+
)
270312

271313
logging.info('Starts to execute mode: %s', mode)
272314
with distribution_strategy.scope():

0 commit comments

Comments
 (0)