Skip to content

Commit 5dd7775

Browse files
authored
Enable strings in guess_dynamic_shapes (#91)
* Enable strings in guess_dynamic_shapes * fix * doc
1 parent 750b6dd commit 5dd7775

File tree

5 files changed

+229
-46
lines changed

5 files changed

+229
-46
lines changed

README.rst

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Enlightening Examples
6767
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
6868
* `Find and fix an export issue due to dynamic shapes
6969
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_locate_issue.html>`_
70-
* `Export with DynamicCache and dynamic shapes
70+
* `Export with DynamicCache and guessed dynamic shapes
7171
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
7272
* `Steel method forward to guess the dynamic shapes (with Tiny-LLM)
7373
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_
@@ -95,10 +95,7 @@ Snapshot of usefuls tools
9595
9696
inputs = (
9797
torch.rand((3, 4), dtype=torch.float16),
98-
[
99-
torch.rand((5, 6), dtype=torch.float16),
100-
torch.rand((5, 6, 7), dtype=torch.float16),
101-
]
98+
[torch.rand((5, 6), dtype=torch.float16), torch.rand((5, 6, 7), dtype=torch.float16)],
10299
)
103100
104101
# with shapes
@@ -126,4 +123,33 @@ Snapshot of usefuls tools
126123

127124
**max_diff**
128125

129-
Returns the maximum discrancies across nested containers containing tensors.
126+
.. code-block:: python
127+
128+
import torch
129+
from onnx_diagnostic.helpers import max_diff
130+
131+
print(
132+
max_diff(
133+
(torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
134+
(torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
135+
)
136+
)
137+
138+
::
139+
140+
>>> {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 4.0, "dnan": 0.0}s
141+
142+
**guess_dynamic_shapes**
143+
144+
.. code-block:: python
145+
146+
inputs = [
147+
(torch.randn((5, 6)), torch.randn((1, 6))),
148+
(torch.randn((7, 8)), torch.randn((1, 8))),
149+
]
150+
ds = ModelInputs(model, inputs).guess_dynamic_shapes(auto="dim")
151+
print(ds)
152+
153+
::
154+
155+
>>> (({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'}), {})

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
.. _l-plot-export-with-dynamic-shape:
33
4-
===========================================
5-
Export with DynamicCache and dynamic shapes
6-
===========================================
4+
===================================================
5+
Export with DynamicCache and guessed dynamic shapes
6+
===================================================
77
88
Every LLMs implemented in :epkg:`transformers` use cache.
99
One of the most used is :class:`transformers.cache_utils.DynamicCache`.
@@ -84,6 +84,8 @@ def forward(self, cache, z):
8484
print(string_type(inputs[1], with_shape=True))
8585

8686
# %%
87+
# .. _l-guess-dynamic-shapes-example:
88+
#
8789
# Guess the dynamic shapes
8890
# ========================
8991
#
@@ -112,6 +114,17 @@ def forward(self, cache, z):
112114
)
113115
print(ep)
114116

117+
# %%
118+
# Use string instead of DYNAMIC
119+
# +++++++++++++++++++++++++++++
120+
#
121+
# ONNX exporter considers strings instead of DYNAMIC or AUTO
122+
# to give names to every dimension.
123+
124+
dss = mi.guess_dynamic_shapes(auto="dim")
125+
pprint.pprint(dss)
126+
127+
115128
# %%
116129
# Do we need to guess?
117130
# ++++++++++++++++++++

_doc/index.rst

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
========================================
22
onnx-diagnostic: investigate onnx models
33
========================================
44

@@ -48,7 +48,7 @@ It also implements tools to investigate, validate exported models (ExportedProgr
4848
license
4949

5050
Getting started
51-
+++++++++++++++
51+
===============
5252

5353
::
5454

@@ -63,7 +63,7 @@ or
6363
pip install onnx-diagnostic
6464

6565
Enlightening Examples
66-
+++++++++++++++++++++
66+
=====================
6767

6868
**Where to start to export a model**
6969

@@ -85,7 +85,13 @@ Enlightening Examples
8585
* :ref:`l-plot-failing-onnxruntime-evaluator`
8686
* :ref:`l-plot-failing-model-extract`
8787

88-
**Some Usefuls Tools**
88+
Some Usefuls Tools
89+
==================
90+
91+
string_type
92+
+++++++++++
93+
94+
See :func:`onnx_diagnostic.helpers.string_type`.
8995

9096
.. code-block:: python
9197
@@ -107,6 +113,11 @@ Enlightening Examples
107113

108114
>>> (T10s3x4,#2[T10s5x6,T10s5x6x7])
109115

116+
onnx_dtype_name
117+
+++++++++++++++
118+
119+
See :func:`onnx_diagnostic.helpers.onnx_dtype_name`.
120+
110121
.. code-block:: python
111122
112123
import onnx
@@ -121,7 +132,64 @@ Enlightening Examples
121132
>>> BFLOAT16
122133
>>> INT64
123134

124-
:func:`onnx_diagnostic.helpers.max_diff`, ...
135+
max_diff
136+
++++++++
137+
138+
See :func:`onnx_diagnostic.helpers.max_diff`.
139+
140+
.. code-block:: python
141+
142+
import torch
143+
from onnx_diagnostic.helpers import max_diff
144+
145+
print(
146+
max_diff(
147+
(torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
148+
(torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
149+
)
150+
)
151+
152+
::
153+
154+
>>> {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 4.0, "dnan": 0.0}s
155+
156+
guess_dynamic_shapes
157+
++++++++++++++++++++
158+
159+
See :meth:`onnx_diagnostic.export.ModelInputs.guess_dynamic_shapes`.
160+
161+
.. code-block:: python
162+
163+
inputs = [
164+
(torch.randn((5, 6)), torch.randn((1, 6))),
165+
(torch.randn((7, 8)), torch.randn((1, 8))),
166+
]
167+
ds = ModelInputs(model, inputs).guess_dynamic_shapes(auto="dim")
168+
print(ds)
169+
170+
::
171+
172+
>>> (({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'}), {})
173+
174+
use_dyn_for_str
175+
+++++++++++++++
176+
177+
178+
179+
Older versions
180+
++++++++++++++
181+
182+
* `0.5.0 <../v0.5.0/index.html>`_
183+
* `0.4.4 <../v0.4.4/index.html>`_
184+
* `0.4.3 <../v0.4.3/index.html>`_
185+
* `0.4.2 <../v0.4.2/index.html>`_
186+
* `0.4.1 <../v0.4.1/index.html>`_
187+
* `0.4.0 <../v0.4.0/index.html>`_
188+
* `0.3.0 <../v0.3.0/index.html>`_
189+
* `0.2.2 <../v0.2.2/index.html>`_
190+
* `0.2.1 <../v0.2.1/index.html>`_
191+
* `0.2.0 <../v0.2.0/index.html>`_
192+
* `0.1.0 <../v0.1.0/index.html>`_
125193

126194
The documentation was updated on:
127195

@@ -173,18 +241,3 @@ Size of the package:
173241
df = pandas.DataFrame(statistics_on_folder(os.path.dirname(__file__), aggregation=1))
174242
gr = df[["dir", "ext", "lines", "chars"]].groupby(["ext", "dir"]).sum()
175243
print(gr)
176-
177-
Older versions
178-
++++++++++++++
179-
180-
* `0.5.0 <../v0.5.0/index.html>`_
181-
* `0.4.4 <../v0.4.4/index.html>`_
182-
* `0.4.3 <../v0.4.3/index.html>`_
183-
* `0.4.2 <../v0.4.2/index.html>`_
184-
* `0.4.1 <../v0.4.1/index.html>`_
185-
* `0.4.0 <../v0.4.0/index.html>`_
186-
* `0.3.0 <../v0.3.0/index.html>`_
187-
* `0.2.2 <../v0.2.2/index.html>`_
188-
* `0.2.1 <../v0.2.1/index.html>`_
189-
* `0.2.0 <../v0.2.0/index.html>`_
190-
* `0.1.0 <../v0.1.0/index.html>`_

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,66 @@ def forward(self, cache, z):
470470
),
471471
)
472472

473+
@requires_transformers("4.51")
474+
def test_guess_dynamic_shapes_cache_str(self):
475+
class Model(torch.nn.Module):
476+
def forward(self, cache, z):
477+
return (
478+
z
479+
+ cache.key_cache[0]
480+
+ cache.key_cache[1]
481+
+ cache.value_cache[0]
482+
+ cache.value_cache[1]
483+
)
484+
485+
model = Model()
486+
487+
n_layers = 2
488+
bsize, nheads, slen, dim = 2, 4, 3, 7
489+
cache = make_dynamic_cache(
490+
[
491+
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
492+
for i in range(n_layers)
493+
]
494+
)
495+
z = torch.randn((1, 1, 1, 7))
496+
model(cache, z)
497+
498+
cache2 = make_dynamic_cache(
499+
[
500+
(
501+
torch.randn(bsize, nheads, slen, dim),
502+
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
503+
)
504+
for i in range(n_layers)
505+
]
506+
)
507+
inputs = [
508+
(cache, z),
509+
(cache2, torch.randn((1, 1, 1, 8))),
510+
]
511+
512+
mi = ModelInputs(Model(), inputs)
513+
self.assertIn("DynamicCache", string_type(mi.inputs, with_shape=True))
514+
ds = mi.guess_dynamic_shapes(auto="dim")
515+
print(ds)
516+
self.assertEqual(
517+
ds,
518+
(
519+
(
520+
[
521+
[{}, {}],
522+
[
523+
{0: "dim_0I_1o_0l0", 2: "dim_0I_1o_0l2", 3: "dim_0I_1o_0l3"},
524+
{0: "dim_0I_1o_1l0", 2: "dim_0I_1o_1l2", 3: "dim_0I_1o_1l3"},
525+
],
526+
],
527+
{3: "dim_1I3"},
528+
),
529+
{},
530+
),
531+
)
532+
473533
def test_couple_input_ds_0(self):
474534
T3x4 = torch.rand((3, 4))
475535
T3x1 = torch.rand((3, 1))

0 commit comments

Comments
 (0)