15
15
"""Multitask training driver library."""
16
16
# pytype: disable=attribute-error
17
17
import os
18
- from typing import Any , List , Optional , Tuple
18
+ from typing import Any , List , Mapping , Optional , Tuple , Union
19
19
from absl import logging
20
20
import orbit
21
21
import tensorflow as tf
@@ -44,8 +44,10 @@ def run_experiment(
44
44
mode : str ,
45
45
params : configs .MultiTaskExperimentConfig ,
46
46
model_dir : str ,
47
+ run_post_eval : bool = False ,
47
48
trainer : base_trainer .MultiTaskBaseTrainer = None
48
- ) -> base_model .MultiTaskBaseModel :
49
+ ) -> Union [base_model .MultiTaskBaseModel ,
50
+ Tuple [base_model .MultiTaskBaseModel , Mapping [Any , Any ]]]:
49
51
"""Runs train/eval configured by the experiment params.
50
52
51
53
Args:
@@ -56,6 +58,8 @@ def run_experiment(
56
58
or 'continuous_eval'.
57
59
params: ExperimentConfig instance.
58
60
model_dir: A 'str', a path to store model checkpoints and summaries.
61
+ run_post_eval: Whether to run post eval once after training, metrics logs
62
+ are returned.
59
63
trainer: (optional) A multi-task trainer to use. If none is provided, a
60
64
default one will be created based on `params`.
61
65
@@ -139,7 +143,11 @@ def timeout_fn():
139
143
else :
140
144
raise NotImplementedError ('The mode is not implemented: %s' % mode )
141
145
142
- return model
146
+ if run_post_eval :
147
+ return model , evaluator .evaluate (
148
+ tf .convert_to_tensor (params .trainer .validation_steps )) # pytype: disable=bad-return-type # typed-keras
149
+ else :
150
+ return model
143
151
144
152
145
153
def run_experiment_with_multitask_eval (
0 commit comments