Skip to content

Commit 68d71cf

Browse files
authored
Makes patches work with FakeTensor (#272)
* fix patches * fix patch * fixes shape information * fix issues * add fake * lint * more about fake helper * g * mypy * pypi * fix shape * mypy
1 parent 78a024b commit 68d71cf

File tree

14 files changed

+453
-48
lines changed

14 files changed

+453
-48
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.16
55
++++++
66

7+
* :pr:`272`: makes patches work with FakeTensor
78
* :pr:`270`: add export sample code to export a specific model id with the appropriate inputs
89
* :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
910
* :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
@@ -100,7 +101,7 @@ Change Logs
100101
+++++
101102

102103
* :pr:`178`: add a patch for eager_mask to handle ``assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs``
103-
* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
104+
* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shapes_from_inputs
104105

105106
0.7.3
106107
+++++

README.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ You need then to remove those which are not dynamic in your model.
4343

4444
.. code-block:: python
4545
46-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
46+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
4747
48-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
48+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
4949
5050
It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
5151
See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
@@ -109,13 +109,13 @@ Snapshot of usefuls tools
109109
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
110110
# ...
111111
112-
**all_dynamic_shape_from_inputs**
112+
**all_dynamic_shapes_from_inputs**
113113

114114
.. code-block:: python
115115
116-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
116+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
117117
118-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
118+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
119119
120120
**torch_export_rewrite**
121121

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.fake_tensor_helper
3+
==========================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.fake_tensor_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ onnx_diagnostic.helpers
1111
cache_helper
1212
config_helper
1313
doc_helper
14+
fake_tensor_helper
1415
graph_helper
1516
helper
1617
_log_helper

_doc/index.rst

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ Patches can be enabled as follows with function
3333
# ...
3434
3535
Dynamic shapes are difficult to guess for caches, function
36-
:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
36+
:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
3737
returns a structure defining all dimensions as dynamic.
3838
You need then to remove those which are not dynamic in your model.
3939

4040
.. code-block:: python
4141
42-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
42+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
4343
44-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
44+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
4545
4646
It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
4747
:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.
@@ -134,16 +134,16 @@ See :func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`.
134134
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
135135
# ...
136136
137-
all_dynamic_shape_from_inputs
137+
all_dynamic_shapes_from_inputs
138138
+++++++++++++++++++++++++++++
139139

140-
See :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`.
140+
See :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`.
141141

142142
.. code-block:: python
143143
144-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
144+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
145145
146-
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
146+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
147147
148148
string_type
149149
+++++++++++

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from onnx_diagnostic import doc
2222
from onnx_diagnostic.helpers import string_type
2323
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
24-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
24+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
2525

2626
bsize, nheads, slen, dim = 2, 1, 30, 96
2727

@@ -39,9 +39,9 @@
3939
print(string_type(inputs, with_shape=True))
4040

4141
# %%
42-
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
42+
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
4343
# produces the corresponding dynamic shapes assuming they are all dynamic.
44-
ds = all_dynamic_shape_from_inputs(inputs)
44+
ds = all_dynamic_shapes_from_inputs(inputs)
4545
pprint.pprint(ds)
4646

4747
# %%

_doc/recipes/plot_dynamic_shapes_what.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from onnx_diagnostic import doc
1717
from onnx_diagnostic.helpers import string_type
1818
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
19-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
19+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
2020
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
2121
from onnx_diagnostic.torch_export_patches import torch_export_patches
2222

@@ -34,9 +34,9 @@
3434
print(string_type(inputs, with_shape=True))
3535

3636
# %%
37-
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
37+
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
3838
# produces the corresponding dynamic shapes assuming they are all dynamic.
39-
ds = all_dynamic_shape_from_inputs(inputs)
39+
ds = all_dynamic_shapes_from_inputs(inputs)
4040
pprint.pprint(ds)
4141

4242
# %%
@@ -56,13 +56,13 @@
5656

5757
# %%
5858
# And the input shapes.
59-
ds = all_dynamic_shape_from_inputs(inputs)
59+
ds = all_dynamic_shapes_from_inputs(inputs)
6060
if ds["past_key_values"]:
6161
print("transformers implemented serialization function for StaticCache.")
6262
else:
6363
print("We need to use serialization function implemented in this package.")
6464
with torch_export_patches(patch_transformers=True):
65-
ds = all_dynamic_shape_from_inputs(inputs)
65+
ds = all_dynamic_shapes_from_inputs(inputs)
6666

6767
# %%
6868
# That gives.

_unittests/ut_export/test_shape_helper.py

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
import unittest
22
import torch
33
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
4-
from onnx_diagnostic.export.shape_helper import (
5-
all_dynamic_shape_from_inputs,
6-
guess_dynamic_shapes_from_inputs,
7-
)
4+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
5+
from onnx_diagnostic.torch_export_patches import torch_export_patches
6+
from onnx_diagnostic.helpers import flatten_object
87
from onnx_diagnostic.helpers.cache_helper import (
98
make_dynamic_cache,
109
make_sliding_window_cache,
1110
make_encoder_decoder_cache,
1211
make_static_cache,
1312
)
14-
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
15-
from onnx_diagnostic.torch_export_patches import torch_export_patches
13+
from onnx_diagnostic.export.shape_helper import (
14+
all_dynamic_shapes_from_inputs,
15+
guess_dynamic_shapes_from_inputs,
16+
make_fake_with_dynamic_dimensions,
17+
)
1618

1719

1820
class TestShapeHelper(ExtTestCase):
19-
2021
@requires_transformers("4.52")
2122
@requires_torch("2.7.99")
2223
def test_all_dynamic_shape_from_cache(self):
2324
cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
24-
ds = all_dynamic_shape_from_inputs(cache)
25+
ds = all_dynamic_shapes_from_inputs(cache)
2526
self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds)
2627

2728
@requires_torch("2.7.99")
@@ -122,17 +123,17 @@ def test_all_dynamic_shape_all_transformers_cache(self):
122123
with torch_export_patches(patch_transformers=True):
123124
for cache, exds in caches:
124125
with self.subTest(cache_name=cache.__class__.__name__):
125-
ds = all_dynamic_shape_from_inputs(cache)
126+
ds = all_dynamic_shapes_from_inputs(cache)
126127
self.assertEqual(exds, ds)
127128

128129
@requires_transformers("4.52")
129130
@requires_torch("2.7.99")
130-
def test_all_dynamic_shape_from_inputs(self):
131-
ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
131+
def test_all_dynamic_shapes_from_inputs(self):
132+
ds = all_dynamic_shapes_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
132133
self.assertEqual(({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}), ds)
133-
ds = all_dynamic_shape_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
134+
ds = all_dynamic_shapes_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
134135
self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
135-
ds = all_dynamic_shape_from_inputs(
136+
ds = all_dynamic_shapes_from_inputs(
136137
(torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
137138
)
138139
self.assertEqual(
@@ -145,9 +146,9 @@ def test_all_dynamic_shape_from_inputs(self):
145146

146147
@requires_transformers("4.52")
147148
@requires_torch("2.7.99")
148-
def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
149+
def test_all_dynamic_shapes_from_inputs_dynamic_cache(self):
149150
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
150-
ds = all_dynamic_shape_from_inputs(data["inputs"])
151+
ds = all_dynamic_shapes_from_inputs(data["inputs"])
151152
self.assertEqual(
152153
{
153154
"input_ids": {0: "d_0_0", 1: "d_0_1"},
@@ -184,6 +185,60 @@ def test_guess_dynamic_shapes_from_inputs(self):
184185
guessed,
185186
)
186187

188+
@requires_transformers("4.55")
189+
@requires_torch("2.9")
190+
def test_make_fake_with_dynamic_dimensions_tensor(self):
191+
res = make_fake_with_dynamic_dimensions(
192+
(torch.rand((2, 32, 30, 96), dtype=torch.float16),),
193+
({0: "batch", 2: "cache_length"},),
194+
)
195+
reshaped = res[0][0]
196+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
197+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
198+
self.assertEqual(reshaped.shape[1], 32)
199+
self.assertEqual(reshaped.shape[3], 96)
200+
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
201+
202+
@requires_transformers("4.55")
203+
@requires_torch("2.9")
204+
def test_make_fake_with_dynamic_dimensions_whole(self):
205+
res = make_fake_with_dynamic_dimensions(
206+
dict(
207+
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
208+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
209+
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
210+
past_key_values=make_dynamic_cache(
211+
[
212+
(
213+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
214+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
215+
),
216+
(
217+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
218+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
219+
),
220+
]
221+
),
222+
),
223+
dynamic_shapes={
224+
"input_ids": {0: "batch", 1: "seq_length"},
225+
"attention_mask": {0: "batch", 1: "cache+seq"},
226+
"position_ids": {0: "batch", 1: "seq_length"},
227+
"past_key_values": [
228+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
229+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
230+
],
231+
},
232+
)
233+
flat = flatten_object(res[0], drop_keys=True)
234+
for t in flat:
235+
if len(t.shape) == 4:
236+
self.assertIsInstance(t.shape[0], torch.SymInt)
237+
self.assertIsInstance(t.shape[2], torch.SymInt)
238+
self.assertEqual(t.shape[1], 32)
239+
self.assertEqual(t.shape[3], 96)
240+
self.assertNotEqual(t.shape[0], t.shape[2])
241+
187242

188243
if __name__ == "__main__":
189244
unittest.main(verbosity=2)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
4+
from onnx_diagnostic.helpers import flatten_object
5+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
6+
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, fake_reshape
7+
8+
9+
class TestMakeTensorHelper(ExtTestCase):
10+
11+
def test_fake_reshape_generic(self):
12+
t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
13+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
14+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
15+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
16+
self.assertEqual(reshaped.shape[1], 3)
17+
self.assertEqual(reshaped.shape[3], 5)
18+
19+
def test_fake_reshape_dim_1(self):
20+
t = torch.zeros((1, 3, 4, 5), dtype=torch.float32)
21+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
22+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
23+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
24+
self.assertEqual(reshaped.shape[1], 3)
25+
self.assertEqual(reshaped.shape[3], 5)
26+
27+
def test_fake_reshape_dim_0(self):
28+
t = torch.zeros((0, 3, 4, 5), dtype=torch.float32)
29+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
30+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
31+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
32+
self.assertEqual(reshaped.shape[1], 3)
33+
self.assertEqual(reshaped.shape[3], 5)
34+
35+
def test_fake_reshape_different(self):
36+
t = torch.zeros((2, 3, 2, 5), dtype=torch.float32)
37+
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
38+
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
39+
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
40+
self.assertEqual(reshaped.shape[1], 3)
41+
self.assertEqual(reshaped.shape[3], 5)
42+
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
43+
44+
@requires_transformers("4.55")
45+
def test_fake_inputs(self):
46+
inputs, _ = make_fake(
47+
dict(
48+
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
49+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
50+
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
51+
past_key_values=make_dynamic_cache(
52+
[
53+
(
54+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
55+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
56+
),
57+
(
58+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
59+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
60+
),
61+
]
62+
),
63+
)
64+
)
65+
flat = flatten_object(inputs, drop_keys=True)
66+
for t in flat:
67+
self.assertIsInstance(t, torch.Tensor)
68+
assert all(
69+
isinstance(s, torch.SymInt) for s in t.shape
70+
), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"
71+
72+
73+
if __name__ == "__main__":
74+
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ def main(argv: Optional[List[Any]] = None):
11281128
raise ValueError(
11291129
f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
11301130
)
1131-
parser = parsers[cmd]()
1131+
parser = parsers[cmd]() # type: ignore[operator]
11321132
parser.parse_args(argv[1:])
11331133
raise RuntimeError("The programme should have exited before.")
11341134

0 commit comments

Comments
 (0)