Skip to content

Commit 48b5662

Browse files
committed
more about fake helper
1 parent 02b212d commit 48b5662

File tree

10 files changed

+208
-97
lines changed

10 files changed

+208
-97
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Change Logs
44
0.7.16
55
++++++
66

7-
* :pr:`272`: makes patches woth with FakeTensor
7+
* :pr:`272`: makes patches work with FakeTensor
88
* :pr:`270`: add export sample code to export a specific model id with the appropriate inputs
99
* :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
1010
* :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
@@ -101,7 +101,7 @@ Change Logs
101101
+++++
102102

103103
* :pr:`178`: add a patch for eager_mask to handle ``assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs``
104-
* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
104+
* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shapes_from_inputs
105105

106106
0.7.3
107107
+++++

README.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ You need then to remove those which are not dynamic in your model.
4343

4444
.. code-block:: python
4545
46-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
46+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
4747
48-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
48+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
4949
5050
It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
5151
See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
@@ -109,13 +109,13 @@ Snapshot of usefuls tools
109109
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
110110
# ...
111111
112-
**all_dynamic_shape_from_inputs**
112+
**all_dynamic_shapes_from_inputs**
113113

114114
.. code-block:: python
115115
116-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
116+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
117117
118-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
118+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
119119
120120
**torch_export_rewrite**
121121

_doc/index.rst

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ Patches can be enabled as follows with function
3333
# ...
3434
3535
Dynamic shapes are difficult to guess for caches, function
36-
:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
36+
:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
3737
returns a structure defining all dimensions as dynamic.
3838
You need then to remove those which are not dynamic in your model.
3939

4040
.. code-block:: python
4141
42-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
42+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
4343
44-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
44+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
4545
4646
It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
4747
:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.
@@ -134,16 +134,16 @@ See :func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`.
134134
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
135135
# ...
136136
137-
all_dynamic_shape_from_inputs
137+
all_dynamic_shapes_from_inputs
138138
+++++++++++++++++++++++++++++
139139

140-
See :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`.
140+
See :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`.
141141

142142
.. code-block:: python
143143
144-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
144+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
145145
146-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
146+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
147147
148148
string_type
149149
+++++++++++

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from onnx_diagnostic import doc
2222
from onnx_diagnostic.helpers import string_type
2323
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
24-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
24+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
2525

2626
bsize, nheads, slen, dim = 2, 1, 30, 96
2727

@@ -39,9 +39,9 @@
3939
print(string_type(inputs, with_shape=True))
4040

4141
# %%
42-
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
42+
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
4343
# produces the corresponding dynamic shapes assuming they are all dynamic.
44-
ds = all_dynamic_shape_from_inputs(inputs)
44+
ds = all_dynamic_shapes_from_inputs(inputs)
4545
pprint.pprint(ds)
4646

4747
# %%

_doc/recipes/plot_dynamic_shapes_what.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from onnx_diagnostic import doc
1717
from onnx_diagnostic.helpers import string_type
1818
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
19-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
19+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
2020
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
2121
from onnx_diagnostic.torch_export_patches import torch_export_patches
2222

@@ -34,9 +34,9 @@
3434
print(string_type(inputs, with_shape=True))
3535

3636
# %%
37-
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
37+
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
3838
# produces the corresponding dynamic shapes assuming they are all dynamic.
39-
ds = all_dynamic_shape_from_inputs(inputs)
39+
ds = all_dynamic_shapes_from_inputs(inputs)
4040
pprint.pprint(ds)
4141

4242
# %%
@@ -56,13 +56,13 @@
5656

5757
# %%
5858
# And the input shapes.
59-
ds = all_dynamic_shape_from_inputs(inputs)
59+
ds = all_dynamic_shapes_from_inputs(inputs)
6060
if ds["past_key_values"]:
6161
print("transformers implemented serialization function for StaticCache.")
6262
else:
6363
print("We need to use serialization function implemented in this package.")
6464
with torch_export_patches(patch_transformers=True):
65-
ds = all_dynamic_shape_from_inputs(inputs)
65+
ds = all_dynamic_shapes_from_inputs(inputs)
6666

6767
# %%
6868
# That gives.

_unittests/ut_export/test_shape_helper.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
import unittest
22
import torch
33
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
4-
from onnx_diagnostic.export.shape_helper import (
5-
all_dynamic_shape_from_inputs,
6-
guess_dynamic_shapes_from_inputs,
7-
)
4+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
5+
from onnx_diagnostic.torch_export_patches import torch_export_patches
86
from onnx_diagnostic.helpers.cache_helper import (
97
make_dynamic_cache,
108
make_sliding_window_cache,
119
make_encoder_decoder_cache,
1210
make_static_cache,
1311
)
14-
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
15-
from onnx_diagnostic.torch_export_patches import torch_export_patches
12+
from onnx_diagnostic.export.shape_helper import (
13+
all_dynamic_shapes_from_inputs,
14+
guess_dynamic_shapes_from_inputs,
15+
)
1616

1717

1818
class TestShapeHelper(ExtTestCase):
@@ -21,7 +21,7 @@ class TestShapeHelper(ExtTestCase):
2121
@requires_torch("2.7.99")
2222
def test_all_dynamic_shape_from_cache(self):
2323
cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
24-
ds = all_dynamic_shape_from_inputs(cache)
24+
ds = all_dynamic_shapes_from_inputs(cache)
2525
self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds)
2626

2727
@requires_torch("2.7.99")
@@ -122,17 +122,17 @@ def test_all_dynamic_shape_all_transformers_cache(self):
122122
with torch_export_patches(patch_transformers=True):
123123
for cache, exds in caches:
124124
with self.subTest(cache_name=cache.__class__.__name__):
125-
ds = all_dynamic_shape_from_inputs(cache)
125+
ds = all_dynamic_shapes_from_inputs(cache)
126126
self.assertEqual(exds, ds)
127127

128128
@requires_transformers("4.52")
129129
@requires_torch("2.7.99")
130-
def test_all_dynamic_shape_from_inputs(self):
131-
ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
130+
def test_all_dynamic_shapes_from_inputs(self):
131+
ds = all_dynamic_shapes_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
132132
self.assertEqual(({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}), ds)
133-
ds = all_dynamic_shape_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
133+
ds = all_dynamic_shapes_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
134134
self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
135-
ds = all_dynamic_shape_from_inputs(
135+
ds = all_dynamic_shapes_from_inputs(
136136
(torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
137137
)
138138
self.assertEqual(
@@ -145,9 +145,9 @@ def test_all_dynamic_shape_from_inputs(self):
145145

146146
@requires_transformers("4.52")
147147
@requires_torch("2.7.99")
148-
def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
148+
def test_all_dynamic_shapes_from_inputs_dynamic_cache(self):
149149
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
150-
ds = all_dynamic_shape_from_inputs(data["inputs"])
150+
ds = all_dynamic_shapes_from_inputs(data["inputs"])
151151
self.assertEqual(
152152
{
153153
"input_ids": {0: "d_0_0", 1: "d_0_1"},

_unittests/ut_helpers/test_fake_tensor_helper.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,44 @@
33
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
44
from onnx_diagnostic.helpers import flatten_object
55
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
6-
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
6+
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, fake_reshape
77

88

99
class TestMakeTensorHelper(ExtTestCase):
10+
11+
def test_fake_reshape_generic(self):
12+
t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
13+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
14+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
15+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
16+
self.assertEqual(reshaped.shape[1], 3)
17+
self.assertEqual(reshaped.shape[3], 5)
18+
19+
def test_fake_reshape_dim_1(self):
20+
t = torch.zeros((1, 3, 4, 5), dtype=torch.float32)
21+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
22+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
23+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
24+
self.assertEqual(reshaped.shape[1], 3)
25+
self.assertEqual(reshaped.shape[3], 5)
26+
27+
def test_fake_reshape_dim_0(self):
28+
t = torch.zeros((0, 3, 4, 5), dtype=torch.float32)
29+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
30+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
31+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
32+
self.assertEqual(reshaped.shape[1], 3)
33+
self.assertEqual(reshaped.shape[3], 5)
34+
35+
def test_fake_reshape_different(self):
36+
t = torch.zeros((2, 3, 2, 5), dtype=torch.float32)
37+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
38+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
39+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
40+
self.assertEqual(reshaped.shape[1], 3)
41+
self.assertEqual(reshaped.shape[3], 5)
42+
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
43+
1044
@requires_transformers("4.55")
1145
def test_fake_inputs(self):
1246
inputs, _ = make_fake(

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,23 @@
22
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
33
import numpy as np
44
import torch
5-
from ..helpers import string_type, flatten_object
5+
from ..helpers import string_type
66
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
7-
from ..helpers.fake_tensor_helper import make_fake
87

98
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
109

1110

12-
def flatten_dynamic_shapes(ds: Any) -> Any:
11+
def _flatten_dynamic_shapes(ds: Any) -> Any:
1312
"""Flattens the dynamic shapes."""
1413
if isinstance(ds, list):
15-
return _flat_list([flatten_dynamic_shapes(t) for t in ds])
14+
return _flat_list([_flatten_dynamic_shapes(t) for t in ds])
1615
if isinstance(ds, tuple):
17-
return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds]))
16+
return tuple(_flat_list([_flatten_dynamic_shapes(t) for t in ds]))
1817
if isinstance(ds, dict):
1918
if all(isinstance(i, int) for i in ds):
2019
# That's a dynamic shape
2120
return ds
22-
return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()])
21+
return _flat_list([_flatten_dynamic_shapes(t) for t in ds.values()])
2322
raise AssertionError(f"Not implemented for {type(ds)}: {ds}")
2423

2524

@@ -33,51 +32,6 @@ def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
3332
return res
3433

3534

36-
def make_fake_with_dynamic_dimensions(
37-
x: Optional[Any],
38-
dynamic_shapes: Any,
39-
fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
40-
) -> Optional[Tuple["FakeTensor", "FaleTensorMode"]]: # noqa: F821
41-
"""
42-
Replaces all tensors by fake tensor respecting the same
43-
constraints as the following dynamic shapes.
44-
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
45-
46-
.. runpython::
47-
:showcode:
48-
49-
from onnx_diagnostic.export.dynamic_shapes import make_fake_with_dynamic_dimensions
50-
51-
inputs, _ = make_fake_with_dynamic_dimensions(
52-
dict(
53-
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
54-
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
55-
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
56-
past_key_values=make_dynamic_cache(
57-
[
58-
(
59-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
60-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
61-
),
62-
(
63-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
64-
torch.rand((2, 32, 30, 96), dtype=torch.float16),
65-
),
66-
]
67-
),
68-
)
69-
)
70-
print(inputs)
71-
"""
72-
fake_inputs = make_fake(x, fake_mode=fake_mode)
73-
flat_inputs = flatten_object(fake_inputs, drop_keys=True)
74-
flat_ds = flatten_dynamic_shapes(dynamic_shapes)
75-
assert len(flat_inputs) == len(flat_ds), (
76-
f"Mismatch between the number of input tensor {len(flat_inputs)} "
77-
f"and the number of dynamic_shapes {len(flat_ds)}"
78-
)
79-
80-
8135
class CoupleInputsDynamicShapes:
8236
"""
8337
Pair inputs / dynamic shapes.
@@ -426,7 +380,7 @@ def _generic_walker_step(
426380
flat, spec = torch.utils._pytree.tree_flatten(inputs)
427381
if all(isinstance(t, torch.Tensor) for t in flat):
428382
# We need to flatten dynamic shapes as well
429-
ds = flatten_dynamic_shapes(ds)
383+
ds = _flatten_dynamic_shapes(ds)
430384
res = cls._generic_walker_step(
431385
processor, flat, ds, flatten_unflatten=flatten_unflatten
432386
)

0 commit comments

Comments
 (0)