# TFX Tuner Component

| Status        | Proposed |
| :------------ | :------- |
| **Author(s)** | Jiayi Zhao ([email protected]), Amy Wu ([email protected]) |
| **Sponsor**   | Zhitao Li ([email protected]), Tom O'Malley ([email protected]), Matthieu Monsch ([email protected]), Makoto Uchida ([email protected]), Goutham Bhat ([email protected]) |
| **Updated**   | 2020-04-20 |

## Objective

### Goal

* A new Tuner component in TFX for automated hyperparameter tuning, based on
  abstractions from the
  [KerasTuner library](https://github.com/keras-team/keras-tuner), in order to
  reuse its abstractions and algorithms.

### Non Goals

* Natively support multi-worker tuning in the system. As TFX doesn't have the
  ability to manage multi-worker clusters, running multiple trials in parallel
  (parallel tuning) and running each trial in a distributed environment
  (distributed training) are not supported natively. Parallel tuning may
  instead be realized by a particular implementation of TFX Tuner (custom
  Executor), e.g., in the Google Cloud environment.
* Implementation of custom tuners for the
  [KerasTuner library](https://github.com/keras-team/keras-tuner), e.g.,
  built-in EstimatorTuner support, is out of the scope of this design
  discussion. However, a user project can still implement a tuner that
  inherits from
  [`kerastuner.BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
  and provide it to the proposed TFX Tuner component.

## Background and Motivation

A hyperparameter is a parameter whose value is used to control the learning
process of a model or the model itself (e.g., the layers and number of nodes).
By contrast, the values of other parameters (typically node weights) are
learned.

Hyperparameter optimization is a critical part of many machine learning
pipelines. Thus we propose a new TFX component: given a search space that
specifies the hyperparameter configuration (name, type, range, etc.), TFX will
optimize the hyperparameters based on the tuning algorithm.

## User Benefit

This document proposes a built-in TFX Tuner component, which works seamlessly
with Trainer and other TFX components. As the Tuner component will utilize the
[KerasTuner library](https://github.com/keras-team/keras-tuner), all supported
tuning methods will be available to TFX, including custom implementations of
KerasTuner tuners.

## Design Proposal

The TFX Tuner component will be built with the
[KerasTuner library](https://github.com/keras-team/keras-tuner). In the
following sections, we will first briefly go over the KerasTuner library and
several concepts in hyperparameter optimization. Then we will focus on our Tuner
component interface and how we utilize the KerasTuner library. After that, we
will discuss parallel tuning and our plan for Google Cloud integration.

### KerasTuner Library

The following graph shows a typical workflow of hyperparameter tuning under the
KerasTuner framework:

<div style="text-align:center"><img src='20200420-tfx-tuner-component/workflow.png' width='600'></div>

Given a user-provided model that accepts a hyperparameter container, the tuner
searches for optimal hyperparameter values through trials created by the tuning
algorithm. For each trial, values within the search space are assigned to the
hyperparameter container, and the user model is trained with these
hyperparameter values and evaluated based on the objective provided to the
tuner. The evaluation results are reported back to the tuner, and the tuning
algorithm decides the hyperparameter values for the next trial. After reaching
certain conditions, e.g., max trials, the tuner stops the iteration and returns
the optimal hyperparameters.

The KerasTuner library provides the above tuning functionality. Here are some of
its abstractions:

* [`HyperParameters`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/hyperparameters.py):
  Hyperparameter container for both the search space and current values.
* [`Oracle`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/oracle.py):
  Implementation of a hyperparameter tuning algorithm, e.g., random search,
  including state management of the algorithm's progress.
* [`Trial`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/trial.py):
  Provided by the Oracle; contains the hyperparameter values for the current
  iteration.
* [`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py):
  A base tuner interface for the above tuning workflow, responsible for the
  iteration of trial execution:
  * Generates a Trial using the Oracle.
  * Trains the user model with the HyperParameters in the current Trial.
  * Evaluates metrics and reports back to the Oracle for the next Trial.
* [`Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py):
  An implementation of BaseTuner for Keras model tuning.

Note: Other than `Tuner`, the abstractions defined by `HyperParameters`,
`Oracle`, `Trial` and `BaseTuner` are not restricted to Keras models, although
the library is called KerasTuner.

For more details and code examples, please refer to
[the KerasTuner repository](https://github.com/keras-team/keras-tuner).
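
To make the workflow concrete, below is a minimal sketch of using the KerasTuner
library directly, outside of TFX. It assumes KerasTuner 1.0; the model, search
space, and placeholder data are illustrative only and are not part of the
proposed component.

```python
import kerastuner
import numpy as np
import tensorflow as tf

# Placeholder data purely for illustration.
x_train = np.random.rand(256, 20).astype('float32')
y_train = np.random.randint(0, 10, size=(256,))
x_val = np.random.rand(64, 20).astype('float32')
y_val = np.random.randint(0, 10, size=(64,))


def build_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
  # The HyperParameters container supplies the values chosen for this trial.
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(
          hp.Int('num_units', min_value=32, max_value=128, step=32),
          activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax'),
  ])
  model.compile(
      optimizer=tf.keras.optimizers.Adam(
          hp.Choice('learning_rate', [1e-2, 1e-3])),
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'])
  return model


# RandomSearch bundles a Tuner (trial runner) with a random-search Oracle;
# search() runs the trial loop described above and reports results back to
# the Oracle after each trial.
tuner = kerastuner.RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=3,
    directory='/tmp/kerastuner_demo',
    project_name='demo')
tuner.search(x_train, y_train, epochs=2, validation_data=(x_val, y_val))
best_hp = tuner.get_best_hyperparameters()[0]
print(best_hp.values)
```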

### Component Interface

The Tuner component takes raw or transformed examples as input, along with a
schema or transform graph for the feature specification, and outputs the
hyperparameter tuning results. Below is the specification of the Tuner
component:

```python
class TunerSpec(ComponentSpec):
  """ComponentSpec for TFX Tuner Component."""

  PARAMETERS = {
      # Specify a python module file which contains a UDF `tuner_fn`.
      'module_file': ExecutionParameter(type=(str, Text), optional=True),
      # Specify the steps for the training stage of each trial's execution.
      'train_args': ExecutionParameter(type=trainer_pb2.TrainArgs),
      'eval_args': ExecutionParameter(type=trainer_pb2.EvalArgs),
  }

  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
      'schema': ChannelParameter(
          type=standard_artifacts.Schema, optional=True),
      'transform_graph':
          ChannelParameter(
              type=standard_artifacts.TransformGraph, optional=True),
  }

  OUTPUTS = {
      'best_hyperparameters':
          ChannelParameter(type=standard_artifacts.HyperParameters),
  }
```

Trainer has an optional hyperparameters input; the tuning result can be fed into
it so that Trainer can utilize the best hyperparameters to construct the model.
Below is an example of how Tuner and Trainer are chained in the pipeline:

```python
# TrainerSpec:
  INPUTS = {
      ...
      'hyperparameters':
          ChannelParameter(
              type=standard_artifacts.HyperParameters, optional=True),
  }

# Pipeline DSL Example:
  tuner = Tuner(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file,
      train_args=trainer_pb2.TrainArgs(num_steps=1000),
      eval_args=trainer_pb2.EvalArgs(num_steps=500))

  trainer = Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      hyperparameters=tuner.outputs['best_hyperparameters'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))
```
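
As a variant, since both `schema` and `transform_graph` are optional in the
`TunerSpec` above, a pipeline that uses Transform could feed transformed
examples and the transform graph instead of raw examples plus schema. The
following is only a sketch, assuming a Transform component named `transform` in
the same pipeline:

```python
# Sketch: Tuner consuming Transform outputs instead of raw examples + schema.
tuner = Tuner(
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    module_file=module_file,
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500))
```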

For Trainer, users need to define the model code and training logic
([Generic Trainer](https://github.com/tensorflow/tfx/blob/r0.21.2/docs/guide/trainer.md#generic-trainer))
in the module_file. For Tuner, in addition to the model code, users also need to
define the hyperparameters, the search space, and a tuning algorithm in the
module_file. A `tuner_fn` with the following signature is required for Tuner:

```python
from typing import Any, Dict, NamedTuple, Text

from kerastuner.engine import base_tuner
import tensorflow as tf
from tfx.components.trainer.executor import TrainerFnArgs

# The current TrainerFnArgs will be renamed to FnArgs as a util class.
FnArgs = TrainerFnArgs
TunerFnResult = NamedTuple('TunerFnResult',
                           [('tuner', base_tuner.BaseTuner),
                            ('fit_kwargs', Dict[Text, Any])])


def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Build the tuner using the KerasTuner API.

  Args:
    fn_args: Holds args as name/value pairs.
      working_dir: working dir for tuning. Automatically set by Executor.
      train_files: List of file paths containing training tf.Example data.
      eval_files: List of file paths containing eval tf.Example data.
      train_steps: number of train steps.
      eval_steps: number of eval steps.
      schema: optional schema file of the input data.
      transform_graph: optional transform graph produced by TFT.

  Returns:
    A namedtuple containing the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to the tuner's run_trial function for fitting
          the model, e.g., the training and validation datasets. The required
          args depend on the above tuner's implementation.
  """
```

The TunerFnResult returned by the above `tuner_fn` contains an instance that
implements the
[`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
interface; that is the contract required by Tuner for tuning. The model code,
hyperparameters, search space and tuning algorithm are hidden under the
BaseTuner abstraction, so the Tuner itself is generic and agnostic to the model
framework and tuning logic. Below is an example module file with a Keras model:

```python
import kerastuner
import tensorflow as tf
...

def _input_fn(file_pattern: Text, ...) -> tf.data.Dataset:
  ...

# Model code for Trainer and Tuner.
def _build_keras_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
  ...
  for _ in range(hp.get('num_layers')):
    ...
  ...
  model = tf.keras.Model(...)
  model.compile(
      optimizer=tf.keras.optimizers.Adam(hp.get('learning_rate')),
      loss='sparse_categorical_crossentropy',
      metrics=[tf.keras.metrics.Accuracy()])
  return model

# This will be called by TFX Tuner.
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  hp = kerastuner.HyperParameters()
  # Defines search space.
  hp.Choice('learning_rate', [1e-1, 1e-3])
  hp.Int('num_layers', 1, 5)

  # RandomSearch is a subclass of Keras model Tuner.
  tuner = kerastuner.RandomSearch(
      _build_keras_model,
      max_trials=5,
      hyperparameters=hp,
      allow_new_entries=False,
      objective='val_accuracy',
      directory=fn_args.working_dir,
      project_name='test')

  train_dataset = _input_fn(fn_args.train_files, ...)
  eval_dataset = _input_fn(fn_args.eval_files, ...)

  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={'x': train_dataset,
                  'validation_data': eval_dataset,
                  'steps_per_epoch': fn_args.train_steps,
                  'validation_steps': fn_args.eval_steps})

# This will be called by TFX Generic Trainer.
def run_fn(fn_args: FnArgs) -> None:
  hp = kerastuner.HyperParameters.from_config(fn_args.hyperparameters)
  model = _build_keras_model(hp)
  model.fit(...)
  model.save(...)
```

In Tuner's executor, `tuner_fn` will be called with information resolved from
the component inputs. We then call the `search` function of the returned tuner
with `fit_kwargs` to launch trials for tuning, and finally emit the best trial's
hyperparameters:

```python
# Executor of the Tuner component:
class Executor(base_executor.BaseExecutor):

  def Do(self,
         input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    ...
    tuner_spec = tuner_fn(self._create_fn_args(input_dict, exec_properties))
    tuner_spec.tuner.search(**tuner_spec.fit_kwargs)
    # The output file contains the JSON string of hyperparameters.get_config().
    self._emit_best_hyperparameters(
        output_dict, tuner_spec.tuner.get_best_hyperparameters()[0])
```
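
For illustration only, the emit step could be sketched as below;
`_emit_best_hyperparameters` and the output file name are hypothetical details
of this proposal rather than an existing TFX API:

```python
import json
import os

from tfx.types import artifact_utils
from tfx.utils import io_utils


def _emit_best_hyperparameters(output_dict, best_hparams) -> None:
  """Writes best_hparams.get_config() as JSON to the output artifact (sketch)."""
  artifact = artifact_utils.get_single_instance(
      output_dict['best_hyperparameters'])
  # Hypothetical file name; the artifact URI is managed by the orchestrator.
  output_path = os.path.join(artifact.uri, 'best_hyperparameters.txt')
  io_utils.write_string_file(output_path,
                             json.dumps(best_hparams.get_config()))
```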

### Parallel Tuning

In parallel tuning, multiple trials are executed in parallel. In this section,
we discuss how distribution works for the KerasTuner library and the current
status of TFX.

In the `search` function of the tuner, trials are run in sequence rather than in
parallel. To support parallel tuning, we need to launch multiple tuners (the
tuner here refers to the one in the KerasTuner library, not the TFX Tuner
component) and have an optimization service that manages the state of the tuning
algorithm, with which each tuner's oracle communicates to retrieve the trials
for that tuner.

<div style="text-align:center"><img src='20200420-tfx-tuner-component/parallel_tuning.png' width='600'></div>

The above graph shows a parallel tuning job with three tuners. Each tuner runs
as a different worker and retrieves trials from its own oracle, which talks to
the optimization service. Trials of different tuners can run in parallel, but
trials within the same tuner still execute in sequence. When the tuners are
launched, the same identifier is assigned to each oracle, so the optimization
service knows they belong to the same tuning job group and will assign
hyperparameter values for their trials based on the algorithm.

The number of parallel tuners can be passed to the component via `TuneArgs` as
shown below:

```python
# Args specific to tuning.
message TuneArgs {
  # Number of trials to run in parallel.
  # Each trial will be trained and evaluated by separate worker jobs.
  int32 num_parallel_trials = 1;
}

class TunerSpec(ComponentSpec):

  PARAMETERS = {
      ...
      'tune_args': ExecutionParameter(type=tuner_pb2.TuneArgs),
  }
```
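
In the pipeline DSL, this could be requested as in the following sketch, which
reuses the names from the earlier pipeline example and assumes the proto above
is exposed as `tuner_pb2.TuneArgs`:

```python
# Sketch: request three tuners, so trials can run in parallel across workers.
tuner = Tuner(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=module_file,
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500),
    tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3))
```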

The KerasTuner library allows users to configure a
[`tf.distribute.Strategy`](https://www.tensorflow.org/tutorials/distribute/keras)
if they are using the
[`kerastuner.Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py)
class (or subclasses of it). In the above parallel tuning setup, each trial
(each model training) is executed on a single worker, so only single-machine
strategies are allowed. To support multi-worker distributed training, we need to
be able to execute a trial (training) across different workers.

At the time of writing, the KerasTuner library can be used for parallel tuning
with a single-machine `tf.distribute.Strategy`, e.g.,
[`MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy);
multi-worker strategy support (distributed training for a trial) is on the
roadmap (note that cluster management is not part of the library).
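
For example, a single-machine strategy could be plugged into the earlier
module-file example roughly as follows. This is only a sketch: it reuses
`_build_keras_model`, `hp`, and `fn_args` from that example and relies on the
`distribution_strategy` argument accepted by `kerastuner.Tuner` subclasses in
KerasTuner 1.0:

```python
import kerastuner
import tensorflow as tf

# Each trial's model training runs under MirroredStrategy on local devices;
# trials themselves are still scheduled by the (possibly parallel) tuners.
tuner = kerastuner.RandomSearch(
    _build_keras_model,
    hyperparameters=hp,
    allow_new_entries=False,
    objective='val_accuracy',
    max_trials=5,
    distribution_strategy=tf.distribute.MirroredStrategy(),
    directory=fn_args.working_dir,
    project_name='mirrored_example')
```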

At the time of writing, TFX doesn't have the ability to manage a multi-worker
cluster or the centralized optimization service, so parallel tuning and
distributed training are not supported natively in TFX (local or on-prem). In
the next section, we discuss the integration for Google Cloud; similar parallel
tuning support can be built for other execution environments.

### Google Cloud Integration

In this section, we discuss the Tuner component with
[Google Cloud AI Platform](https://cloud.google.com/ai-platform) (CAIP):
specifically, an implementation of the KerasTuner Oracle that talks to the
[AI Platform Optimizer](https://cloud.google.com/ai-platform/optimizer/docs/overview)
as the centralized optimization service, and a custom Tuner executor
implementation that makes use of the Cloud Optimizer-based Oracle (symbol names
are subject to change).

As mentioned in the parallel tuning section above, KerasTuner uses a centralized
optimization service that manages the states of a tuning study and its trials.
In addition to that, we will create a `CloudOracle` as a client to the AI
Platform Optimizer service, and a `CloudTuner` which inherits from the Keras
[Tuner](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py).
In the module file, users create the `tuner_fn` with `CloudTuner`, and then they
configure the TFX Tuner component to use a custom Tuner executor
(`CloudExecutor`), which launches multiple `CloudTuner`s on a Google Cloud AI
Platform Training job with possibly multiple worker machines running various
trials concurrently. Below is the workflow for in-process tuning and Cloud
tuning.

<div style="text-align:center"><img src='20200420-tfx-tuner-component/cloud.png' width='600'></div>

## Future work

* Native support for multi-worker parallel tuning.
* Examples of custom Tuners (inheriting from BaseTuner), e.g., for Estimator
  support or Keras custom training loop support.