@@ -27,13 +27,26 @@ onnx-diagnostic: investigate onnx models
2727
2828The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches >`_:
2929it helps exporting **pytorch models into ONNX **, mostly designed for LLMs using dynamic caches.
30+ Patches can be enabled as follows:
3031
3132.. code-block :: python
3233
34+ from onnx_diagnostic.torch_export_patches import torch_export_patches
35+
3336 with torch_export_patches(patch_transformers = True ) as f:
3437 ep = torch.export.export(model, args, kwargs = kwargs, dynamic_shapes = dynamic_shapes)
3538 # ...
3639
40+ Dynamic shapes are difficult to guess for caches, one function
41+ returns a structure defining all dimensions as dynamic.
42+ You need then to remove those which are not dynamic in your model.
43+
44+ .. code-block :: python
45+
46+ from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
47+
48+ dynamic_shapes = all_dynamic_shape_from_inputs(cache)
49+
3750 It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
3851See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/ >`_ and
3952`torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches >`_.
@@ -90,14 +103,26 @@ Snapshot of usefuls tools
90103
91104.. code-block :: python
92105
106+ from onnx_diagnostic.torch_export_patches import torch_export_patches
107+
93108 with torch_export_patches(patch_transformers = True ) as f:
94109 ep = torch.export.export(model, args, kwargs = kwargs, dynamic_shapes = dynamic_shapes)
95110 # ...
96111
112+ **all_dynamic_shape_from_inputs **
113+
114+ .. code-block :: python
115+
116+ from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
117+
118+ dynamic_shapes = all_dynamic_shape_from_inputs(cache)
119+
97120 **torch_export_rewrite **
98121
99122.. code-block :: python
100123
124+ from onnx_diagnostic.torch_export_patches import torch_export_rewrite
125+
101126 with torch_export_rewrite(rewrite = [Model.forward]) as f:
102127 ep = torch.export.export(model, args, kwargs = kwargs, dynamic_shapes = dynamic_shapes)
103128 # ...
0 commit comments