# TPU SavedModel Export API for TF2.x

| Status        | Proposed                                                 |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [NNN](https://github.com/tensorflow/community/pull/NNN)  |
:               : (update when you have community PR #)                    :
| **Sponsor**   | [email protected]                                        |
| **Updated**   | 2019-11-06                                                |

## Objective

Provide an API that allows TF2 users to export TPU saved models <b>for
inference</b>, which:

+   Provides a user-friendly way to specify which function should run on TPU;
+   Hides graph construction and TPU-specific inference logic (multi-core
    support, etc.) from users;
+   Allows specifying tags in the SavedModel.

## Motivation
### Limitation of current `tf.saved_model.save()`

MetaGraphDef allows saving customized tags. Current downstream components, like
the TPU model server and the TFX infra-validator, use the tags to load the
specific MetaGraph. However, `tf.saved_model.save()` does not allow users to
specify the set of tags in the MetaGraphDef; instead, it hard-codes the
MetaGraph to carry only a single 'serve' tag.

### Special Logic in TPU Inference Graph

Under the status quo, TPU computations have to be represented by a drastically
different graph from the CPU one. Inference-specific requirements (e.g. batching
and core selection) add another layer of complexity.

Some major differences between a CPU graph and a TPU graph:

+   As a protocol between the TensorFlow Graph and TF2XLA, TPU device placement
    of a node is done by attaching the `_tpu_replicate` attribute;
+   For multi-core efficiency, TPU computations have to be encapsulated in a
    function, saved in the FunctionLibrary, and called through
    TPUPartitionedCall. A TPUOrdinalSelector node has to be connected to
    TPUPartitionedCall to do efficient round-robin core selection;
+   Variable nodes have to be lifted out of TPU functions, rewritten as
    VarHandleOp, and consumed by ReadVariableOp.

Also, to reduce the number of TPU compilations, serving platforms (for example,
Servomatic) prefer to batch inference requests into a few allowed batch sizes.
This requires wrapping TPUPartitionedCall in yet another function that is
called through BatchFunction.

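To make the extra plumbing concrete, here is a minimal sketch of the kind of
wiring users would otherwise write by hand today. It only shows the rewrite
step; `tf.compat.v1.tpu.rewrite` usage and the `predict_step` body are
stand-ins for illustration, and the multi-core dispatch through
TPUPartitionedCall / TPUOrdinalSelector would still have to be added on top.

```python
import tensorflow as tf


def predict_step(images):
  # Stand-in for the real model computation.
  return tf.reduce_sum(images, axis=-1)


@tf.function
def tpu_predict_step(images):
  # tpu.rewrite wraps the computation in a TPU function and attaches the
  # `_tpu_replicate` attribute described above.
  return tf.compat.v1.tpu.rewrite(predict_step, [images])
```

The `@tf.tpu.function` decorator proposed below is meant to hide exactly this
kind of boilerplate.
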
Below is an intuitive example of how a TPU graph is different from a CPU one:

![Original CPU Graph](20191106-tf2-tpu-savedmodel/cpu_graph.png)
<center>Original CPU Graph.</center>

![TPU Graph](20191106-tf2-tpu-savedmodel/tpu_graph.png)
<center>TPU Graph.</center>

### User Control of Device Placement

There has to be a way for users to specify which part of the computation should
be placed on TPU, because there is no perfect device placement policy that works
for every use case. For example, even though dense embedding ops are allowed on
TPU, serving models might still want to run embedding lookups on CPU because the
embeddings are too big to fit on TPU.

![Customized Embeddings](20191106-tf2-tpu-savedmodel/customized_embeddings.png)
<center>Example of user control. In this graph, both 'custom_embedding' and
'dense' can run on TPU. But users may want 'custom_embedding' to run on CPU for
whatever reason, e.g. the CPU computation can be parallelized, or they don't
have enough TPU resources. In this case, there has to be a way for them to tell
SavedModel that only 'dense' is to run on TPU.</center>

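With the export API proposed below, that split can be expressed by keeping the
embedding lookup in the outer serving function and placing only the dense part
under the TPU decorator. A minimal sketch; `custom_embedding` and `dense_layer`
are hypothetical callables named after the figure:

```python
@tf.tpu.function                        # proposed decorator, introduced below
def dense_part(embedded):
  return dense_layer(embedded)          # 'dense' runs on TPU


@tf.function
def serve(ids):
  embedded = custom_embedding(ids)      # 'custom_embedding' stays on CPU
  return dense_part(embedded)
```
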
## User Benefit

Enable TPU inference for TF2 SavedModels.

## Design Proposal
### User Facing API

<b>For General TF2 Users</b>

Users need to do the following things to export a TPU SavedModel in TF2.x:

1.  Replace `@tf.function` with `@tf.tpu.function` for functions they wish to
    run on TPU:

    ```python
    @tf.tpu.function
    def predict_step(images):
      ...
    ```

2.  Create the main serving function and call the TPU function above. The main
    function might contain additional TF ops which can't run on TPU (e.g.
    `tf.io.decode_image`):

    ```python
    @tf.function
    def serve(images):
      image_tensors = tf.io.decode_image(images)
      return predict_step(image_tensors)
    ```

    Then create the signatures and tags:

    ```python
    signatures = {
        'serving_default':
            serve.get_concrete_function(...),
    }
    tags = [tag_constants.SERVING, tag_constants.TPU]
    ```

3.  Pass both the signatures and the tags to `tf.saved_model.save()`:

    ```python
    tf.saved_model.save(
        model,
        export_dir='...',
        signatures=signatures,
        tags=tags)
    ```

The resulting TPU inference graph looks like this:

![Resulting TPU Graph](20191106-tf2-tpu-savedmodel/tpu_result.png)
<center>Resulting TPU Graph.</center>

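Downstream consumers can then select the TPU MetaGraph by these tags. A quick
illustration with `tf.saved_model.load` (the export directory is a
placeholder):

```python
loaded = tf.saved_model.load(
    '...', tags=[tag_constants.SERVING, tag_constants.TPU])
```
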
<b>For Advanced Users who need customized Ops</b>

In such cases, we provide the flexibility for users to tweak `@tf.tpu.function`.

1.  If users do not wish to use TPUPartitionedCall, they can disable it:

    ```python
    @tf.tpu.function(use_tpu_partitioned_call=False)
    def predict_step(images):
      ...
    ```

2.  Users can nest TPU functions within BatchFunction (a parameterized sketch
    follows this list):

    ```python
    @batch_ops.nondifferentiable_batch_function
    @tf.tpu.function
    def predict_step(images):
      ...
    ```

3.  Users can also use their own customized PartitionedCallOp:

    ```python
    @batch_ops.nondifferentiable_batch_function
    @my_partitioned_call_op_constructor
    @tf.tpu.function(use_tpu_partitioned_call=False)
    def predict_step(images):
      ...
    ```

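In practice, BatchFunction takes batching parameters. A hedged sketch using the
public `tf.nondifferentiable_batch_function` wrapper; the parameter values below
are illustrative only:

```python
@tf.nondifferentiable_batch_function(
    num_batch_threads=2,
    max_batch_size=8,
    batch_timeout_micros=5000,
    allowed_batch_sizes=[1, 4, 8])  # few allowed sizes keep TPU compilations down
@tf.tpu.function
def predict_step(images):
  ...
```
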
<b>For Keras Users</b>

Keras users only need to pass `export_to_tpu=True` to save a TPU SavedModel.
(Currently, we require the Keras model being saved to be completely
TPU-compatible.)

```python
tf.keras.models.save_model(
    model,
    filepath='...',
    export_to_tpu=True)
```

### Changes to TF2.x API

1.  In addition to taking the keyword argument `signatures`,
    `tf.saved_model.save()` will take an optional argument `tags`.

    Originally, the concrete functions specified by `signatures` are saved in
    one MetaGraph, which has the 'serve' tag hard-coded.

    `tags` is an optional argument. It is a Python iterable representing the
    list of tags for the MetaGraph. This allows users to specify customized
    tags.

2.  Implement an additional `@tf.tpu.function` decorator in
    `tensorflow/python/tpu/tpu.py`. This decorator handles TPU rewriting under
    the hood.

3.  Add an additional `use_tpu_partitioned_call` keyword argument to
    `def_function.function()` and `Function.__init__()`. This argument will be
    passed through to the place where the PartitionedCallOp is created.
    Originally, all stateful functions generate a StatefulPartitionedCallOp;
    now we switch to TPUPartitionedCallOp, and this routing is done by checking
    the value of `use_tpu_partitioned_call`. A sketch of the affected
    signatures follows this list.

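A minimal sketch of the two signatures after these changes; the existing
parameters are abridged, and the exact ordering and defaults here are
assumptions for illustration, not part of the proposal text:

```python
# tf.saved_model.save with the new optional `tags` argument (sketch).
def save(obj, export_dir, signatures=None, options=None, tags=None):
  ...


# def_function.function with the new `use_tpu_partitioned_call` argument
# (other existing keyword arguments elided).
def function(func=None, input_signature=None, autograph=True,
             use_tpu_partitioned_call=False, **other_kwargs):
  ...
```
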
### Changes to Keras API

Keras users would like `tf.keras.models.save_model()` to work directly for
exporting a TPU SavedModel, without having knowledge of tf.function / tags /
signatures. The only way to achieve this is to hide that logic under
`tf.keras.models.save_model()`.

After the change, `tf.keras.models.save_model()` will have two additional
arguments:

1.  `export_to_tpu`: Simply setting this to `True` will export a TPU model;
2.  `tags_signatures`: Optionally, for advanced users who want more control
    over which tags / signatures they are using, this argument can be used as
    if they were using the TF2.x saving API.

## Detailed Design

### TF2.x API

Under the hood, the exporter API does the following things:

+   The `@tf.tpu.function` decorator wraps the user-specified function;
+   `use_tpu_partitioned_call`, as an attribute of the Function class, controls
    whether a TPUPartitionedCall is generated instead of a
    StatefulPartitionedCall;
+   The MetaGraph is tagged with the user-defined tags.

<b>Step 1:</b> Use a new decorator to wrap a TPU version of the user-specified
function. It calls `tpu.rewrite` inside the original function to generate a TPU
version of the graph. By default, this creates only the TPU function. If users
wish to preserve both the CPU and the TPU function, they can set
`preserve_cpu_fn=True`.

```python
# tensorflow/python/tpu/tpu.py

FunctionCollection = namedtuple('FunctionCollection', ['tpu_fn', 'cpu_fn'])


def _rewrite_func_wrapper(func):

  def tpu_fn(*x):
    return rewrite(func, x)

  return tpu_fn


@tf_export("tpu.function")
def tpu_function(func=None, *args, **kwargs):
  """Compiles a TPU function into a callable TensorFlow graph."""

  def inner_func(func):
    preserve_cpu_fn = False
    if 'preserve_cpu_fn' in kwargs:
      preserve_cpu_fn = kwargs['preserve_cpu_fn']
      del kwargs['preserve_cpu_fn']

    if preserve_cpu_fn:
      cpu_fn = def_function.function(func, *args, **kwargs)

    kwargs.update({'use_tpu_partitioned_call': True})
    tpu_func = _rewrite_func_wrapper(func)
    tpu_fn = def_function.function(tpu_func, *args, **kwargs)

    if preserve_cpu_fn:
      func_collection = FunctionCollection(tpu_fn=tpu_fn, cpu_fn=cpu_fn)
      return func_collection
    else:
      return tpu_fn

  if func:
    return inner_func(func)
  else:
    return inner_func
```

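A hedged usage sketch of the `preserve_cpu_fn` option: when it is set, the
decorator returns the `FunctionCollection` namedtuple defined above, and the
caller picks the variant explicitly (the function body and inputs are
placeholders):

```python
@tf.tpu.function(preserve_cpu_fn=True)
def predict_step(images):
  ...

# predict_step is now a FunctionCollection(tpu_fn=..., cpu_fn=...).
tpu_outputs = predict_step.tpu_fn(images)   # TPU-rewritten version
cpu_outputs = predict_step.cpu_fn(images)   # original CPU version
```
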
<b>Step 2:</b> Pass the `use_tpu_partitioned_call` argument all the way through
to `functional_ops.py`, where a TPUPartitionedCall will be created instead of a
StatefulPartitionedCall.

```python
# tensorflow/python/ops/functional_ops.py

if hasattr(f, "_use_tpu_partitioned_call") and f._use_tpu_partitioned_call:
  outputs = tpu_functional.TPUPartitionedCall(
      args=args,
      device_ordinal=tpu_ops.tpu_ordinal_selector(),
      Tout=tout,
      f=f)
```

<b>Step 3:</b> Create the MetaGraph for the SavedModel, tagged with the
user-specified tags.

```python
# tensorflow/python/saved_model/save.py

saved_model = saved_model_pb2.SavedModel()
...
meta_graph_def = saved_model.meta_graphs.add()
asset_info, exported_graph = _fill_meta_graph_def(
    meta_graph_def, saveable_view, signatures,
    options.namespace_whitelist,
    tags=list(tags))
...
```

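One way to sanity-check the result is to read the SavedModel proto back and
inspect the tags on each MetaGraph. A small sketch using the internal
`loader_impl` helper (the export directory is a placeholder):

```python
from tensorflow.python.saved_model import loader_impl

saved_model_proto = loader_impl.parse_saved_model('...')
for meta_graph in saved_model_proto.meta_graphs:
  print(meta_graph.meta_info_def.tags)  # e.g. ['serve', 'tpu']
```
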
### Support for Keras saving API

Add an argument `export_to_tpu` to `tf.keras.models.save_model()` which, if set
to true, will rewrite the model for TPU inference.

Add an argument `tags` to `tf.keras.models.save_model()` which has the same
semantics as in `tf.saved_model.save()`.

```python
# tensorflow/python/keras/saving/save.py

@keras_export('keras.models.save_model')
def save_model(model,
               filepath,
               overwrite=True,
               include_optimizer=True,
               save_format=None,
               signatures=None,
               tags=None,
               export_to_tpu=False,
               options=None):
  ...
  if (export_to_tpu and
      (not tags or tag_constants.TPU not in tags)):
    checkpoint_graph_view = save_lib._AugmentedGraphView(model)
    signatures = find_function_to_export_tpu(checkpoint_graph_view)
    tags = [tag_constants.SERVING, tag_constants.TPU]

  saved_model_save.save(model, filepath, overwrite,
                        include_optimizer,
                        signatures,
                        tags,
                        options)
```