# TFX Tuner Component

| Status        | Proposed                                                   |
| :------------ | :--------------------------------------------------------- |
| **Author(s)** | Jiayi Zhao ([email protected]), Amy Wu ([email protected]) |
| **Sponsor**   | Zhitao Li ([email protected]), Tom O'Malley ([email protected]), Matthieu Monsch ([email protected]), Makoto Uchida ([email protected]), Goutham Bhat ([email protected]) |
| **Updated**   | 2020-04-20 |

## Objective

### Goal

*   A new Tuner component in TFX for automated hyperparameter tuning, which is
    based on abstractions from the
    [KerasTuner library](https://github.com/keras-team/keras-tuner), in order
    to reuse abstractions and algorithms from the latter.

### Non Goal

*   Natively supporting multi-worker tuning in the system. As TFX does not have
    the ability to manage multi-worker clusters, running multiple trials in
    parallel (parallel tuning) and running each trial in a distributed
    environment (distributed training) are not supported natively. Parallel
    tuning may instead be realized by a particular implementation of the TFX
    Tuner (a custom Executor), e.g., in a Google Cloud environment.
*   Implementation of custom tuners for the
    [KerasTuner library](https://github.com/keras-team/keras-tuner) is out of
    the scope of this design discussion, e.g., built-in EstimatorTuner support.
    However, a user project can still implement a tuner that inherits from
    [`kerastuner.BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
    and provide it to the proposed TFX Tuner component.

## Background and Motivation

A hyperparameter is a parameter whose value is used to control the learning
process of a model or the model itself (e.g., the number of layers and nodes).
By contrast, the values of other parameters (typically node weights) are
learned.

Hyperparameter optimization is a critical part of many machine learning
pipelines. We therefore propose a new TFX component: given a search space that
specifies the hyperparameter configuration (name, type, range, etc.), TFX will
optimize the hyperparameters based on the tuning algorithm.

## User Benefit

This document proposes a built-in TFX Tuner component, which works seamlessly
with Trainer and other TFX components. As the Tuner component will utilize the
[KerasTuner library](https://github.com/keras-team/keras-tuner), all supported
tuning methods will be available to TFX, including custom implementations of
KerasTuner.

## Design Proposal

The TFX Tuner component will be built with the
[KerasTuner library](https://github.com/keras-team/keras-tuner). In the
following sections, we will first briefly go over the KerasTuner library and
several concepts in hyperparameter optimization. Then we will focus on the
Tuner component interface and how we utilize the KerasTuner library. After
that, we will discuss parallel tuning and our plan for Google Cloud
integration.

### KerasTuner Library

The following diagram shows a typical workflow of hyperparameter tuning under
the KerasTuner framework:

<div style="text-align:center"><img src='20200420-tfx-tuner-component/workflow.png', width='600'></div>

Given a user-provided model that accepts a hyperparameter container, the tuner
searches for optimal values through trials created by the tuning algorithm. For
each trial, values from the search space are assigned to the hyperparameter
container, and the user model is trained with these hyperparameter values and
evaluated against the objective provided to the tuner. The evaluation results
are reported back to the tuner, and the tuning algorithm decides the
hyperparameter values for the next trial. Once certain conditions are reached,
e.g., the maximum number of trials, the tuner stops iterating and returns the
optimal hyperparameters.

The KerasTuner library provides the above tuning functionality; here are some
of its abstractions:

*   [`HyperParameters`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/hyperparameters.py):
    Hyperparameter container for both the search space and current values.
*   [`Oracle`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/oracle.py):
    Implementation of a hyperparameter tuning algorithm, e.g., random search,
    including state management of the algorithm’s progress.
*   [`Trial`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/trial.py):
    Provided by the Oracle; contains information about the hyperparameter
    values for the current iteration.
*   [`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py):
    A base tuner interface for the above tuning workflow, responsible for the
    iteration of trial execution:
    *   Generates a Trial using the Oracle.
    *   Trains the user model with the HyperParameters in the current Trial.
    *   Evaluates metrics and reports back to the Oracle for the next Trial.
*   [`Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py):
    An implementation of BaseTuner for Keras model tuning.

Note: Other than the Tuner, the abstractions defined by `HyperParameters`,
`Oracle`, `Trial` and `BaseTuner` are not restricted to Keras models, although
the library is called KerasTuner.

For more details and code examples, please refer to
[here](https://github.com/keras-team/keras-tuner).
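
For orientation, below is a minimal sketch of the `HyperParameters` container;
the search-space definitions mirror those used later in this document, and the
default-value behavior noted in the comments reflects our understanding of the
library rather than anything specified by this proposal:

```python
import kerastuner

# Define a search space in a HyperParameters container.
hp = kerastuner.HyperParameters()
hp.Choice('learning_rate', [1e-1, 1e-3])  # categorical values
hp.Int('num_layers', 1, 5)                # integer range

# Before an Oracle assigns trial-specific values, retrieval returns the
# defaults (typically the first choice / the minimum of the range).
print(hp.get('learning_rate'), hp.get('num_layers'))

# The container round-trips through a JSON-friendly config; this is how the
# Tuner component's output is later consumed by the Trainer (see `run_fn`
# below).
config = hp.get_config()
restored = kerastuner.HyperParameters.from_config(config)
```
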
### Component Interface

The Tuner component takes raw or transformed examples as input, along with a
schema or transform_graph for the feature specification, and outputs the
hyperparameter tuning results. Below is the specification of the Tuner
component:

```python
class TunerSpec(ComponentSpec):
  """ComponentSpec for TFX Tuner Component."""

  PARAMETERS = {
      # Specify a python module file which contains a UDF `tuner_fn`.
      'module_file': ExecutionParameter(type=(str, Text), optional=True),
      # Specify the steps for the training stage of each trial’s execution.
      'train_args': ExecutionParameter(type=trainer_pb2.TrainArgs),
      'eval_args': ExecutionParameter(type=trainer_pb2.EvalArgs),
  }

  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
      'schema': ChannelParameter(
          type=standard_artifacts.Schema, optional=True),
      'transform_graph':
          ChannelParameter(
              type=standard_artifacts.TransformGraph, optional=True),
  }

  OUTPUTS = {
      'best_hyperparameters':
          ChannelParameter(type=standard_artifacts.HyperParameters),
  }
```

Trainer has an optional hyperparameters input; the tuning result can be fed
into it so that Trainer can utilize the best hyperparameters to construct the
model. Below is an example of how Tuner and Trainer are chained in a pipeline:

```python
# TrainerSpec:
  INPUTS = {
      ...
      'hyperparameters':
          ChannelParameter(
              type=standard_artifacts.HyperParameters, optional=True),
  }

# Pipeline DSL Example:
tuner = Tuner(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=module_file,
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500))

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    hyperparameters=tuner.outputs['best_hyperparameters'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
```

For Trainer, users need to define the model code and training logic
([Generic Trainer](https://github.com/tensorflow/tfx/blob/r0.21.2/docs/guide/trainer.md#generic-trainer))
in the module_file. For Tuner, in addition to the model code, users also need
to define the hyperparameters, the search space and a tuning algorithm in the
module_file. A `tuner_fn` with the following signature is required for Tuner:

```python
from typing import Any, Dict, NamedTuple, Text

from kerastuner.engine import base_tuner
import tensorflow as tf
from tfx.components.trainer.executor import TrainerFnArgs

# Current TrainerFnArgs will be renamed to FnArgs as a util class.
FnArgs = TrainerFnArgs
TunerFnResult = NamedTuple('TunerFnResult',
                           [('tuner', base_tuner.BaseTuner),
                            ('fit_kwargs', Dict[Text, Any])])


def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Builds the tuner using the KerasTuner API.

  Args:
    fn_args: Holds args as name/value pairs.
      working_dir: working dir for tuning. Automatically set by Executor.
      train_files: List of file paths containing training tf.Example data.
      eval_files: List of file paths containing eval tf.Example data.
      train_steps: number of train steps.
      eval_steps: number of eval steps.
      schema: optional schema file of the input data.
      transform_graph: optional transform graph produced by TFT.

  Returns:
    A namedtuple containing the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to the tuner’s run_trial function for fitting
          the model, e.g., the training and validation datasets. Required args
          depend on the above tuner’s implementation.
  """
```

The TunerFnResult returned by the above `tuner_fn` contains an instance that
implements the
[`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
interface; this is the contract required by Tuner for tuning. The model code,
hyperparameters, search space and tuning algorithm are hidden under the
BaseTuner abstraction, so the Tuner itself is generic and agnostic to the model
framework and tuning logic. Below is an example module file with a Keras model:

```python
import kerastuner
import tensorflow as tf
...

def _input_fn(file_pattern: Text, ...) -> tf.data.Dataset:
  ...

# Model code for Trainer and Tuner.
def _build_keras_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
  ...
  for _ in range(hp.get('num_layers')):
    ...
  ...
  model = tf.keras.Model(...)
  model.compile(
      optimizer=tf.keras.optimizers.Adam(hp.get('learning_rate')),
      loss='sparse_categorical_crossentropy',
      metrics=[tf.keras.metrics.Accuracy()])
  return model

# This will be called by the TFX Tuner.
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  hp = kerastuner.HyperParameters()
  # Defines the search space.
  hp.Choice('learning_rate', [1e-1, 1e-3])
  hp.Int('num_layers', 1, 5)

  # RandomSearch is a subclass of the Keras model Tuner.
  tuner = kerastuner.RandomSearch(
      _build_keras_model,
      max_trials=5,
      hyperparameters=hp,
      allow_new_entries=False,
      objective='val_accuracy',
      directory=fn_args.working_dir,
      project_name='test')

  train_dataset = _input_fn(fn_args.train_files, ...)
  eval_dataset = _input_fn(fn_args.eval_files, ...)

  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={'x': train_dataset,
                  'validation_data': eval_dataset,
                  'steps_per_epoch': fn_args.train_steps,
                  'validation_steps': fn_args.eval_steps})

# This will be called by the TFX Generic Trainer.
def run_fn(fn_args: FnArgs) -> None:
  hp = kerastuner.HyperParameters.from_config(fn_args.hyperparameters)
  model = _build_keras_model(hp)
  model.fit(...)
  model.save(...)
```

In Tuner’s executor, `tuner_fn` will be called with information resolved from
the component inputs; we then call the `search` function of the returned tuner
with `fit_kwargs` to launch trials for tuning, and finally emit the best
trial’s hyperparameters:

```python
# Executor of the Tuner component:
class Executor(base_executor.BaseExecutor):

  def Do(self,
         input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    ...
    tuner_spec = tuner_fn(self._create_fn_args(input_dict, exec_properties))
    tuner_spec.tuner.search(**tuner_spec.fit_kwargs)
    # The output file contains a JSON-format string of
    # hyperparameters.get_config().
    self._emit_best_hyperparameters(
        output_dict, tuner_spec.tuner.get_best_hyperparameters()[0])
```

### Parallel Tuning

In parallel tuning, multiple trials are executed in parallel. In this section,
we discuss how distribution works for the KerasTuner library and the current
status of TFX.

In the `search` function of a tuner, trials run in sequence rather than in
parallel. To support parallel tuning, we need to launch multiple tuners (the
tuner here refers to the one in the KerasTuner library, not the TFX Tuner
component) and have an optimization service that manages the state of the
tuning algorithm; the oracle of each tuner communicates with this service and
retrieves the trials for its tuner.

<div style="text-align:center"><img src='20200420-tfx-tuner-component/parallel_tuning.png', width='600'></div>

The above graph shows parallel tuning with three tuners. Each tuner runs as a
different worker and retrieves trials from its own oracle, which talks to the
optimization service. Trials of different tuners can run in parallel, but
trials within the same tuner still execute in sequence. When the tuners are
launched, the same identifier is assigned to each oracle, so the optimization
service knows they belong to the same tuning job group and assigns
hyperparameter values for their trials based on the algorithm.

The number of parallel tuners can be passed to the component via `TuneArgs`, as
shown below:

```python
# Args specific to tuning.
message TuneArgs {
  # Number of trials to run in parallel.
  # Each trial will be trained and evaluated by separate worker jobs.
  int32 num_parallel_trials = 1;
}

class TunerSpec(ComponentSpec):

  PARAMETERS = {
      ...
      'tune_args': ExecutionParameter(type=tuner_pb2.TuneArgs),
  }
```
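
For illustration, a pipeline could request parallel tuning as sketched below.
This is a minimal sketch that assumes the `tune_args` parameter is exposed on
the Tuner component constructor as proposed above; the other arguments are as
in the earlier pipeline example:

```python
# Sketch only: assumes `tune_args` is exposed as proposed in this RFC.
tuner = Tuner(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=module_file,
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500),
    # Run three tuner workers, i.e., up to three trials in parallel.
    tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3))
```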

The KerasTuner library allows users to configure a
[`tf.distribute.Strategy`](https://www.tensorflow.org/tutorials/distribute/keras)
if they are using the
[`kerastuner.Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py)
class (or subclasses of it). In the parallel tuning described above, each trial
(each model training) is executed on a single worker, so only single-machine
strategies are allowed. To support multi-worker distributed training, we need
to be able to execute the trial (training) on different workers.

At the time of writing, the KerasTuner library can be used for parallel tuning
with a single-machine `tf.distribute.Strategy`, e.g.,
[`MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy);
multi-worker strategy (distributed training for a trial) support is on the
roadmap (note that cluster management is not part of the library).
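
As a sketch of the single-machine case, a `kerastuner.Tuner` subclass such as
`RandomSearch` accepts a `distribution_strategy` argument; the snippet below
reuses `_build_keras_model`, `hp` and `fn_args` from the earlier module-file
example and is illustrative only:

```python
import kerastuner
import tensorflow as tf

# Each trial still runs on one worker, but trains across that worker's local
# devices via a single-machine MirroredStrategy.
tuner = kerastuner.RandomSearch(
    _build_keras_model,
    max_trials=5,
    hyperparameters=hp,
    objective='val_accuracy',
    distribution_strategy=tf.distribute.MirroredStrategy(),
    directory=fn_args.working_dir,
    project_name='test')
```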

At the time of writing, TFX does not have the ability to manage a multi-worker
cluster or a centralized optimization service, so parallel tuning and
distributed training are not supported natively in TFX (local or on-prem). In
the next section, we discuss the integration for Google Cloud; similar parallel
tuning support can be built for other execution environments.

### Google Cloud Integration

In this section, we discuss the Tuner component with
[Google Cloud AI Platform](https://cloud.google.com/ai-platform) (CAIP):
specifically, an implementation of a KerasTuner Oracle that talks to the
[AI Platform Optimizer](https://cloud.google.com/ai-platform/optimizer/docs/overview)
as the centralized optimization service, and a custom Tuner executor
implementation that makes use of the Cloud Optimizer-based Oracle (symbol names
are subject to change).


As mentioned in the parallel tuning section above, KerasTuner uses a
centralized optimization service that manages the states of a tuning study and
its trials. In addition, we will create a `CloudOracle` as a client to the AI
Platform Optimizer service, and a `CloudTuner` which inherits from the Keras
[Tuner](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py).
In the module file, users create the `tuner_fn` with `CloudTuner`, and then
they configure the TFX Tuner component to use a custom Tuner executor
(`CloudExecutor`), which launches multiple `CloudTuner`s on a Google Cloud AI
Platform Training job, possibly with multiple worker machines running various
trials concurrently. Below is the workflow for in-process tuning and Cloud
tuning.

<div style="text-align:center"><img src='20200420-tfx-tuner-component/cloud.png', width='600'></div>
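
For illustration only, the pipeline-level wiring for Cloud tuning might look
like the sketch below. `CloudExecutor` is the placeholder executor name used in
this RFC, and attaching it through TFX's `custom_executor_spec` hook is an
assumption about how the integration could be wired, not part of the finalized
design:

```python
from tfx.components.base import executor_spec

# Hypothetical sketch: the module_file's tuner_fn returns a CloudTuner as
# described above; CloudExecutor stands in for the custom Cloud executor.
tuner = Tuner(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=module_file,
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500),
    tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3),
    custom_executor_spec=executor_spec.ExecutorClassSpec(CloudExecutor))
```
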
## Future work

*   Native support for multi-worker parallel tuning.
*   Custom Tuner (inherits from BaseTuner) examples, e.g., for Estimator
    support or Keras custom training loop support.