# TPU SavedModel Export API for TF2.x

| Status        | Proposed                                                |
| :------------ | :------------------------------------------------------ |
| **RFC #**     | [NNN](https://github.com/tensorflow/community/pull/NNN) |
:               : (update when you have community PR #)                   :
| **Sponsor**   | [email protected]                                        |
| **Updated**   | 2019-11-06                                              |

## Objective

Provide an API to allow TF2 users to export TPU saved models <b>for
inference</b>, which:

+ Provides a user-friendly way to specify which function to run on TPU;
+ Hides graph construction and TPU-inference-specific logic (multi-core
  support, etc.) from users;
+ Allows specifying tags in the SavedModel.

## Motivation

### Limitation of current `tf.saved_model.save()`

MetaGraphDef allows saving customized tags, and downstream components such as
the TPU model server and the TFX infra-validator use these tags to load a
specific MetaGraph. However, `tf.saved_model.save()` does not allow users to
specify the set of tags in the MetaGraphDef; it hard-codes the MetaGraph to
have the single ‘serve’ tag.
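
To make this concrete, here is a minimal sketch of the status quo (the path
and the `model` / `signatures` objects are illustrative):

```python
import tensorflow as tf

# Today, save() always writes a single MetaGraph tagged ['serve'].
tf.saved_model.save(model, '/tmp/model', signatures=signatures)

tf.saved_model.load('/tmp/model', tags=['serve'])         # Works.
tf.saved_model.load('/tmp/model', tags=['serve', 'tpu'])  # Fails: no MetaGraph
                                                          # carries the 'tpu'
                                                          # tag.
```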

### Special Logic in TPU Inference Graph

Under the status quo, TPU computations have to be represented by a drastically
different graph from the CPU one. Inference-specific requirements (e.g.
batching / core selection) add another layer of complexity.

Some major differences between the CPU and TPU graphs:

+ As a protocol between the TensorFlow Graph and TF2XLA, TPU device placement
  of a node is done by attaching the `_tpu_replicate` attribute;
+ For multicore efficiency, TPU computations have to be encapsulated as a
  function, saved in the FunctionLibrary, and called by TPUPartitionedCall. A
  TPUOrdinalSelector node has to be connected to TPUPartitionedCall to do
  efficient round-robin core selection;
+ Variable nodes have to be lifted out of TPU functions, rewritten as
  VarHandleOp, and consumed by ReadVariableOp.

Also, to reduce the number of TPU compilations, serving platforms (for example,
Servomatic) prefer to batch inference requests into a few allowed batch sizes.
This requires wrapping TPUPartitionedCall in another function, which is then
called by BatchFunction.

Below is an intuitive example of how a TPU graph is different from a CPU one:


<center>Original CPU Graph.</center>


<center>TPU Graph.</center>

### User Control of Device Placement

There has to be a way for users to specify which part of the computation should
be placed on TPU, because no device placement policy works perfectly for every
use case. For example, even though dense embedding ops are allowed on TPU,
serving models might still want to run embedding lookups on CPU because the
embeddings are too big to fit on TPU.


<center>Example of user control. In this graph, both ‘custom_embedding’ and
‘dense’ can run on TPU, but users may want ‘custom_embedding’ to run on CPU for
some reason, e.g. the CPU computations can be parallelized, or they don’t have
enough TPU resources. In this case, there has to be a way for them to tell
SavedModel that only ‘dense’ is to run on TPU.</center>

## User Benefit

Enables TPU inference for TF2 SavedModels.

## Design Proposal

### User Facing API

<b>For General TF2 Users</b>

Users need to do the following things to export a TPU SavedModel in TF2.x:

1. Replace `@tf.function` with `@tf.tpu.function` for functions they wish to
   run on TPU;

    ```python
    @tf.tpu.function
    def predict_step(images):
      ...
    ```

2. Create the main serving function, which calls the TPU function above. The
   main function might have additional TF ops which can’t run on TPU (e.g.
   `tf.io.decode_image`):

    ```python
    @tf.function
    def serve(images):
      image_tensors = tf.io.decode_image(images)
      return predict_step(image_tensors)
    ```

   And then create a signature and tags:

    ```python
    signatures = {
        'serving_default':
            serve.get_concrete_function(...),
    }
    tags = [tag_constants.SERVING, tag_constants.TPU]
    ```

3. Pass both the signatures and the tags to `tf.saved_model.save()`:

    ```python
    tf.saved_model.save(
        model,
        export_dir='...',
        signatures=signatures,
        tags=tags)
    ```

The resulting TPU inference graph looks like this:


<center>Resulting TPU Graph.</center>

<b>For Advanced Users who need customized Ops</b>

In such cases, we provide the flexibility for users to tweak `@tf.tpu.function`.

1. If users do not wish to use TPUPartitionedCall, they can disable it:

    ```python
    @tf.tpu.function(use_tpu_partitioned_call=False)
    def predict_step(images):
      ...
    ```

2. Users can nest TPU functions within BatchFunction:

    ```python
    @batch_ops.nondifferentiable_batch_function
    @tf.tpu.function
    def predict_step(images):
      ...
    ```

3. Users can also use their own customized PartitionedCallOp:

    ```python
    @batch_ops.nondifferentiable_batch_function
    @my_partitioned_call_op_constructor
    @tf.tpu.function(use_tpu_partitioned_call=False)
    def predict_step(images):
      ...
    ```

<b>For Keras Users</b>

Keras users only need to pass `export_to_tpu=True` to export a TPU SavedModel.
(Currently, we require the Keras model being saved to be completely
TPU-compatible.)

```python
tf.keras.models.save_model(
    model,
    filepath='...',
    export_to_tpu=True)
```

### Changes to TF2.x API

1. In addition to taking the keyword argument `signatures`,
   `tf.saved_model.save()` will take an optional argument `tags`.

   Currently, the concrete functions specified by `signatures` are saved in
   one MetaGraph, which has the ‘serve’ tag hard-coded.

   `tags` is an optional Python iterable representing the list of tags for the
   MetaGraph, which lets users specify customized tags (see the loading sketch
   after this list).

2. Implement an additional `@tf.tpu.function` decorator in
   `tensorflow/python/tpu/tpu.py`. This decorator handles TPU rewriting under
   the hood.

3. An additional `use_tpu_partitioned_call` keyword argument for
   `def_function.function()` and `Function.__init__()`. This argument is
   passed through to the place where the PartitionedCall op is created.
   Currently, all stateful functions generate a StatefulPartitionedCall op;
   with this change, a TPUPartitionedCall op is generated instead, with the
   routing decided by the value of `use_tpu_partitioned_call`.
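
Once the MetaGraph carries customized tags, downstream consumers can select it
by tag at load time. A brief sketch (the path is illustrative;
`tf.saved_model.load()` already accepts a `tags` argument):

```python
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

# Select the TPU MetaGraph by its tag set.
loaded = tf.saved_model.load(
    '/tmp/tpu_model', tags=[tag_constants.SERVING, tag_constants.TPU])
serving_fn = loaded.signatures['serving_default']
```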

### Changes to Keras API

Keras users would like `tf.keras.models.save_model()` to work directly for
exporting a TPU SavedModel, without having knowledge of tf.function / tags /
signatures. The only way to achieve this is to hide that logic under
`tf.keras.models.save_model()`.

After the change, `tf.keras.models.save_model()` will have two additional
arguments:

1. `export_to_tpu`: Simply setting this to `True` will export a TPU model;
2. `tags` (alongside the existing `signatures`): Optionally, advanced users
   who want more control over which tags / signatures they are using can set
   these arguments as in the TF2.x saving API (see the sketch below).
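
For example, an advanced user could combine these arguments (a hypothetical
sketch reusing the `serve` function and argument names from earlier in this
doc):

```python
tf.keras.models.save_model(
    model,
    filepath='...',
    signatures={'serving_default': serve.get_concrete_function(...)},
    tags=[tag_constants.SERVING, tag_constants.TPU],
    export_to_tpu=True)
```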

## Detailed Design

### TF2.x API

Under the hood, the exporter API does the following things:

+ `@tf.tpu.function` wraps the user-specified function;
+ The `use_tpu_partitioned_call` attribute on the `Function` class controls
  whether a TPUPartitionedCall is generated instead of a
  StatefulPartitionedCall;
+ The MetaGraph is tagged with user-defined tags.

<b>Step 1:</b> Use a new decorator to wrap a TPU version of the user-specified
function. It calls `tpu.rewrite` on the original function to generate a TPU
version of the graph. By default, this creates only a TPU function. If users
wish to preserve both the CPU and TPU functions, they can set
`preserve_cpu_fn=True`.

```python
# tensorflow/python/tpu/tpu.py

from collections import namedtuple

from tensorflow.python.eager import def_function
from tensorflow.python.util.tf_export import tf_export

# Holds both function variants when `preserve_cpu_fn=True`.
FunctionCollection = namedtuple('FunctionCollection', ['tpu_fn', 'cpu_fn'])

def _rewrite_func_wrapper(func):
  def tpu_fn(*x):
    # `rewrite` is defined in this module; it wraps `func` in a TPU
    # computation.
    return rewrite(func, x)
  return tpu_fn

@tf_export("tpu.function")
def tpu_function(func=None, *args, **kwargs):
  """Compiles a TPU function into a callable TensorFlow graph."""
  def inner_func(func):
    preserve_cpu_fn = kwargs.pop('preserve_cpu_fn', False)

    if preserve_cpu_fn:
      cpu_fn = def_function.function(func, *args, **kwargs)

    kwargs.update({'use_tpu_partitioned_call': True})
    tpu_func = _rewrite_func_wrapper(func)
    tpu_fn = def_function.function(tpu_func, *args, **kwargs)

    if preserve_cpu_fn:
      return FunctionCollection(tpu_fn=tpu_fn, cpu_fn=cpu_fn)
    else:
      return tpu_fn

  # Supports use both as `@tpu_function` and as `@tpu_function(...)`.
  if func:
    return inner_func(func)
  else:
    return inner_func
```
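
A hypothetical usage sketch of `preserve_cpu_fn` (function bodies elided as in
the examples above):

```python
@tf.tpu.function(preserve_cpu_fn=True)
def predict_step(images):
  ...

# `predict_step` is now a FunctionCollection holding both variants.
cpu_outputs = predict_step.cpu_fn(images)
tpu_outputs = predict_step.tpu_fn(images)
```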

<b>Step 2:</b> Pass the `use_tpu_partitioned_call` argument all the way through
to `functional_ops.py`, where a TPUPartitionedCall is created instead of a
StatefulPartitionedCall:

```python
# tensorflow/python/ops/functional_ops.py

# Route to TPUPartitionedCall when the function was built with
# `use_tpu_partitioned_call=True`; the ordinal selector round-robins requests
# across TPU cores.
if getattr(f, "_use_tpu_partitioned_call", False):
  outputs = tpu_functional.TPUPartitionedCall(
      args=args,
      device_ordinal=tpu_ops.tpu_ordinal_selector(),
      Tout=tout,
      f=f)
```

<b>Step 3:</b> Create MetaGraph for SavedModel.

```python
# tensorflow/python/saved_model/save.py

saved_model = saved_model_pb2.SavedModel()
...
meta_graph_def = saved_model.meta_graphs.add()
asset_info, exported_graph = _fill_meta_graph_def(
    meta_graph_def, saveable_view, signatures,
    options.namespace_whitelist,
    tags=list(tags))
...
```

### Support for Keras saving API

Add an argument `export_to_tpu` to `tf.keras.models.save_model()`, which, if
set to `True`, rewrites the model for TPU inference.

Add an argument `tags` to `tf.keras.models.save_model()`, which has the same
semantics as in `tf.saved_model.save()`.

```python
# tensorflow/python/keras/saving/save.py

@keras_export('keras.models.save_model')
def save_model(model,
               filepath,
               overwrite=True,
               include_optimizer=True,
               save_format=None,
               signatures=None,
               tags=None,
               export_to_tpu=False,
               options=None):
  ...
  if (export_to_tpu and
      (not tags or tag_constants.TPU not in tags)):
    # Discover the serving function, rewrite it for TPU, and tag the
    # MetaGraph accordingly.
    checkpoint_graph_view = save_lib._AugmentedGraphView(model)
    signatures = find_function_to_export_tpu(checkpoint_graph_view)
    tags = [tag_constants.SERVING, tag_constants.TPU]

  saved_model_save.save(model, filepath, overwrite,
                        include_optimizer,
                        signatures,
                        tags,
                        options)
```