@@ -3,40 +3,65 @@ keras.Model.export(
     self,
     filepath,
     format='tf_saved_model',
-    verbose=True
+    verbose=True,
+    input_signature=None,
+    **kwargs
 )
 __doc__
-Create a TF SavedModel artifact for inference.
+Export the model as an artifact for inference.
 
-**Note:** This can currently only be used with
-the TensorFlow or JAX backends.
-
-This method lets you export a model to a lightweight SavedModel artifact
-that contains the model's forward pass only (its `call()` method)
-and can be served via e.g. TF-Serving. The forward pass is registered
-under the name `serve()` (see example below).
+Args:
+    filepath: `str` or `pathlib.Path` object. The path to save the
+        artifact.
+    format: `str`. The export format. Supported values:
+        `"tf_saved_model"` and `"onnx"`. Defaults to
+        `"tf_saved_model"`.
+    verbose: `bool`. Whether to print a message during export. Defaults
+        to `True`.
+    input_signature: Optional. Specifies the shape and dtype of the
+        model inputs. Can be a structure of `keras.InputSpec`,
+        `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. If
+        not provided, it will be automatically computed. Defaults to
+        `None`. See the sketch after this argument list.
+    **kwargs: Additional keyword arguments:
+        - Specific to the JAX backend and `format="tf_saved_model"`:
+            - `is_static`: Optional `bool`. Indicates whether `fn` is
+              static. Set to `False` if `fn` involves state updates
+              (e.g., RNG seeds and counters).
+            - `jax2tf_kwargs`: Optional `dict`. Arguments for
+              `jax2tf.convert`. See the documentation for
+              [`jax2tf.convert`](
+              https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+              If `native_serialization` and `polymorphic_shapes` are
+              not provided, they will be automatically computed.
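+
+If you need to control the exported signature rather than rely on the
+automatically computed one, `input_signature` can be passed explicitly.
+A minimal sketch, assuming a hypothetical `(None, 28, 28)` float32 input:
+
+```python
+# Pin the shape and dtype of the exported inputs explicitly.
+model.export(
+    "path/to/location",
+    format="tf_saved_model",
+    input_signature=[keras.InputSpec(shape=(None, 28, 28), dtype="float32")],
+)
+```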
 
-The original code of the model (including any custom layers you may
-have used) is *no longer* necessary to reload the artifact -- it is
-entirely standalone.
+**Note:** This feature is currently supported only with the TensorFlow,
+JAX, and Torch backends.
 
-Args:
-    filepath: `str` or `pathlib.Path` object. Path where to save
-        the artifact.
-    verbose: whether to print all the variables of the exported model.
+Examples:
 
-Example:
+Here's how to export a TensorFlow SavedModel for inference.
 
 ```python
-# Create the artifact
-model.export("path/to/location")
+# Export the model as a TensorFlow SavedModel artifact
+model.export("path/to/location", format="tf_saved_model")
 
-# Later, in a different process/environment...
+# Load the artifact in a different process/environment
 reloaded_artifact = tf.saved_model.load("path/to/location")
 predictions = reloaded_artifact.serve(input_data)
 ```
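+
+The forward pass is registered under the endpoint name `serve()`, which
+is why the reloaded artifact is called via `reloaded_artifact.serve(...)`.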
 
-If you would like to customize your serving endpoints, you can
-use the lower-level `keras.export.ExportArchive` class. The
-`export()` method relies on `ExportArchive` internally.
+Here's how to export an ONNX artifact for inference.
+
+```python
+# Export the model as an ONNX artifact
+model.export("path/to/location", format="onnx")
+
+# Load the artifact in a different process/environment
+ort_session = onnxruntime.InferenceSession("path/to/location")
+ort_inputs = {
+    k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
+}
+predictions = ort_session.run(None, ort_inputs)
+```
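+
+The snippet above assumes the `onnxruntime` package is installed and
+imported in the serving environment.
+
+On the JAX backend, the `tf_saved_model`-specific kwargs described under
+`**kwargs` can be forwarded through `export()`. A minimal sketch, with an
+illustrative `polymorphic_shapes` entry:
+
+```python
+# Export from the JAX backend, forwarding jax2tf options.
+model.export(
+    "path/to/location",
+    format="tf_saved_model",
+    is_static=True,  # the forward pass performs no state updates
+    jax2tf_kwargs={"polymorphic_shapes": ["(batch, 28, 28)"]},
+)
+```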
 