# TFX Generic Trainer

| Status        | Proposed                                                   |
| :------------ | :--------------------------------------------------------- |
| **Author(s)** | Jiayi Zhao ([email protected])                              |
| **Sponsor**   | Konstantinos Katsiapis ([email protected]), Zhitao Li ([email protected]), Karmel Allison ([email protected]) |
| **Updated**   | 2020-01-17                                                  |

## Objective

### Goal

*   Support any TensorFlow training loop in TFX Trainer in addition to
    tf.estimator, primarily focused on native Keras models.

### Non-Goal

*   Natively support multi-worker distributed training by the system.
*   Non-TF training that generates a SavedModel.

## Background and Motivation

In the current TFX Trainer component, only tf.estimator is supported for
training and generating models. The user provides a module file which contains
a `trainer_fn`; the trainer calls this function to get the estimator model and
the related specs for training, and generates a SavedModel with
`tf.estimator.train_and_evaluate`.

[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level
API for building and training models. It is currently supported in TFX by using
`tf.keras.estimator.model_to_estimator` in the module file. Users can create a
Keras model in their `trainer_fn` but need to convert it to an estimator before
returning it (for example,
[cifar10](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/cifar10/cifar10_utils.py)).

This doc focuses on native Keras support (without model_to_estimator) in TFX.
We propose changing the user-facing API to be more generic so that users can do
(single node) native Keras model training within TFX.

## User Benefit

*   Allows non-estimator based training, especially Keras, as TensorFlow is
    establishing Keras as the
    [standardized high-level API](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a).
*   Allows
    [custom training](https://www.tensorflow.org/tutorials/customization/custom_training)
    for customization of the training loop.

## Detailed Design

Below is the pseudo code for the current TFX Trainer's executor:

```python
class Executor(base_executor.BaseExecutor):

  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Uses a user-supplied tf.estimator to train a tf model locally."""
    trainer_fn = self._GetFn(exec_properties)  # Load from the module file.
    trainer_fn_args = self._GetFnArgs(
        input_dict, output_dict, exec_properties)

    training_spec = trainer_fn(trainer_fn_args)
    tf.estimator.train_and_evaluate(training_spec['estimator'], ...)
    # For TFMA (downstream evaluator and model validator component).
    tfma.export.export_eval_savedmodel(training_spec['estimator'], ...)
```

And the user-supplied module file contains a function called `trainer_fn` which
returns an estimator:

```python
def _build_keras_model() -> tf.keras.Model:
  model = keras.XXX
  model.compile(...)
  return model


def trainer_fn(
    trainer_fn_args: trainer.executor.TrainerFnArgs) -> Dict[Text, Any]:
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.

  Returns:
    A dict of the following:
      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  ...

  estimator = tf.keras.estimator.model_to_estimator(
      keras_model=_build_keras_model(), ...)

  return {
      'estimator': estimator,
      'train_spec': ...,
      'eval_spec': ...,
      'eval_input_receiver_fn': ...
  }
```

We propose that in the generic trainer's module file, the user not only
provides the model but also controls how the model is trained
(`train_and_evaluate` for estimator and `model.fit` for Keras live in the user
module file instead of in the executor). The executor can thus stay generic
with respect to the model, and users can customize the
[training loop](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#training_loop).
The executor pseudo code would look like below:

```python
class Executor(base_executor.BaseExecutor):

  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Train a user-supplied tf model."""
    run_fn = self._GetRunFn(exec_properties)  # Load from the module file.

    # run_fn_args contains:
    # 1. input train and eval data paths.
    # 2. desired output model path for the trained savedmodel.
    # 3. training args, e.g., train/eval steps.
    # 4. optional base model.
    # 5. optional tuning result (kerastuner.HyperParameters config).
    # 6. optional custom config for passing params from component.
    run_fn_args = self._GetRunFnArgs(
        input_dict, output_dict, exec_properties)

    run_fn(run_fn_args)
    # Validates the existence of run_fn's output savedmodel.
    ...
```
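
Purely for illustration, the `run_fn_args` described by the comments above can
be thought of as a simple value object along the following lines (the field
names here are assumptions for this sketch, not the proposed API):

```python
from typing import Any, Dict, List, NamedTuple, Optional, Text


class RunFnArgs(NamedTuple):
  """Hypothetical shape of the args handed to run_fn; field names are assumed."""
  train_files: List[Text]   # 1. Input train data paths.
  eval_files: List[Text]    # 1. Input eval data paths.
  serving_model_dir: Text   # 2. Desired output path for the trained savedmodel.
  train_steps: int          # 3. Training args.
  eval_steps: int           # 3. Training args.
  base_model: Optional[Text] = None                  # 4. Optional base model.
  hyperparameters: Optional[Dict[Text, Any]] = None  # 5. Optional tuning result.
  custom_config: Optional[Dict[Text, Any]] = None    # 6. Optional params from component.
```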

In the module file, the user needs to provide a `run_fn` instead of the
previous `trainer_fn`. The `trainer_fn` was responsible for creating the model;
in addition to that, `run_fn` also needs to handle the training part and output
the trained model to the desired location given by the run args:

```python
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  model.save(...)
```
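
Because the training loop now lives in the user module, `model.fit` could
equally be replaced by a
[custom training loop](https://www.tensorflow.org/tutorials/customization/custom_training).
A minimal sketch, reusing `_build_keras_model` from above and assuming a
hypothetical `_input_fn` helper plus `args` field names such as `train_files`
and `serving_model_dir`:

```python
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
  """Sketch of a run_fn that trains with a custom loop instead of model.fit."""
  model = _build_keras_model()
  optimizer = tf.keras.optimizers.Adam()
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  # `_input_fn` is a hypothetical helper that yields (features, labels) batches.
  for features, labels in _input_fn(args.train_files):
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = loss_fn(labels, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model.save(args.serving_model_dir, save_format='tf')
```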

In the generic trainer, the executor is mainly responsible for handling the
[artifact](https://github.com/tensorflow/tfx/blob/r0.21/docs/guide/index.md#artifacts)
(a unit of data that is passed between components); all model-related logic is
user supplied.

A separate GenericExecutor will be created, and the existing trainer executor
will be sunsetted. We plan to keep the estimator-based executor for one more
version and then deprecate it.

### How to convert an existing estimator based module file

To convert an existing estimator based module file (e.g.,
[iris](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/iris/iris_utils.py))
for the generic trainer, simply add a `run_fn` that calls the `trainer_fn` and
trains the returned model (code that used to live in `trainer.executor.Do`):

```python
def run_fn(fn_args: executor.TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())

  # Reuse the trainer_fn.
  training_spec = trainer_fn(fn_args, schema)

  # Train the model.
  absl.logging.info('Training model.')
  tf.estimator.train_and_evaluate(training_spec['estimator'],
                                  training_spec['train_spec'],
                                  training_spec['eval_spec'])
  absl.logging.info('Training complete. Model written to %s',
                    fn_args.serving_model_dir)

  # Export an eval savedmodel for TFMA. Note that for Keras, an eval savedmodel
  # is not needed as TFMA2 can use the serving model for evaluation.
  absl.logging.info('Exporting eval_savedmodel for TFMA.')
  tfma.export.export_eval_savedmodel(
      estimator=training_spec['estimator'],
      export_dir_base=fn_args.eval_model_dir,
      eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

  absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
```

### tf.distribute.Strategy

Distribution strategy is the user module's responsibility with the new generic
trainer interface. To use it, the user needs to modify the `run_fn()` in the
module file; below is pseudo code for the single-worker and multi-worker cases.

For a single-worker distribution strategy, create an appropriate
[tf.distribute.Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy)
and move the creation and compiling of the Keras model inside `strategy.scope`:

```python
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
  """Build the TF model and train it."""
  mirrored_strategy = tf.distribute.MirroredStrategy()
  with mirrored_strategy.scope():
    model = _build_keras_model()
  model.fit(...)
  model.save(...)
```

For a multi-worker distribution strategy, the TFX Trainer does not have the
ability to spawn a multi-worker cluster with the
[current executor](https://github.com/tensorflow/tfx/blob/r0.21/tfx/components/trainer/executor.py),
hence this is not covered in the scope of this RFC. If the execution
environment of a TFX Trainer implementation has the ability to bring up a
cluster of worker machines and execute the user function on the workers with
the correct
[TF_CONFIG setup](https://www.tensorflow.org/guide/distributed_training#setting_up_tf_config_environment_variable),
such as the GCP AI Platform Training service via
[extensions/google_cloud_ai_platform/trainer/executor.py](https://github.com/tensorflow/tfx/blob/r0.21/tfx/extensions/google_cloud_ai_platform/trainer/executor.py),
the `run_fn()` would look like below:

```python
def _is_chief() -> bool:
  """Decide whether the current worker's role is chief."""
  # Check TF_CONFIG (set by TFX when bringing up the worker) in the execution
  # environment.
  ...


def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
  """Build the TF model and train it."""
  ps_strategy = tf.distribute.experimental.ParameterServerStrategy()
  with ps_strategy.scope():
    model = _build_keras_model()
  model.fit(...)
  if _is_chief():
    model.save(...)
```
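
As an illustration only (not part of the proposal), `_is_chief()` could be
implemented by parsing the `TF_CONFIG` environment variable, assuming the
standard cluster/task JSON layout described in the TF_CONFIG guide linked
above:

```python
import json
import os


def _is_chief() -> bool:
  """Returns True if this worker should export the model.

  Assumes TF_CONFIG has the standard {'cluster': ..., 'task': ...} layout.
  """
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  task = tf_config.get('task', {})
  task_type = task.get('type', 'chief')
  task_index = task.get('index', 0)
  # A 'chief' task is always chief; in clusters without an explicit chief,
  # worker 0 conventionally takes on the chief's duties (such as exporting).
  return task_type == 'chief' or (task_type == 'worker' and task_index == 0)
```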

For details about `tf.distribute.Strategy`, please refer to the
[distributed training guide](https://www.tensorflow.org/guide/distributed_training).

## Future work

*   Examples for custom training loops.
*   Native support for multi-worker distribution.
