Skip to content

Commit 3ab855c

Browse files
committed
fix dynamic shapes
1 parent e16edae commit 3ab855c

File tree

7 files changed

+229
-8
lines changed

7 files changed

+229
-8
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Change Logs
44
0.7.4
55
+++++
66

7-
* :pr:`174`: changes for the next version of onnx
7+
* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
88

99
0.7.3
1010
+++++

_doc/index.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ onnx-diagnostic: investigate onnx models
2121
The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
2222
it helps export **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
2323
Sources available at `github/onnx-diagnostic <https://github.com/sdpython/onnx-diagnostic/>`_.
24-
Patches can be enabled as follows:
24+
Patches can be enabled as follows with the function
25+
:func:`onnx_diagnostic.torch_export_patches.torch_export_patches`:
2526

2627
.. code-block:: python
2728
@@ -31,7 +32,8 @@ Patches can be enabled as follows:
3132
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
3233
# ...
3334
34-
Dynamic shapes are difficult to guess for caches, one function
35+
Dynamic shapes are difficult to guess for caches; the function
36+
:func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
3537
returns a structure defining all dimensions as dynamic.
3638
You need then to remove those which are not dynamic in your model.
3739

_scripts/test_backend_onnxruntime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def supports_device(cls, device: str) -> bool:
5656
d = Device(device)
5757
if d == DeviceType.CPU:
5858
return True
59-
if d == DeviceType.GPU:
59+
if d == DeviceType.CUDA:
6060
import torch
6161

6262
return torch.cuda.is_available()
@@ -65,7 +65,7 @@ def supports_device(cls, device: str) -> bool:
6565
@classmethod
6666
def create_inference_session(cls, model, device):
6767
d = Device(device)
68-
if d == DeviceType.GPU:
68+
if d == DeviceType.CUDA:
6969
providers = ["CUDAExecutionProvider"]
7070
elif d == DeviceType.CPU:
7171
providers = ["CPUExecutionProvider"]

_unittests/ut_export/test_shape_helper.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,148 @@
55
all_dynamic_shape_from_inputs,
66
guess_dynamic_shapes_from_inputs,
77
)
8+
from onnx_diagnostic.helpers.cache_helper import (
9+
make_dynamic_cache,
10+
make_sliding_window_cache,
11+
make_encoder_decoder_cache,
12+
make_static_cache,
13+
make_mamba_cache,
14+
)
815
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
16+
from onnx_diagnostic.torch_export_patches import torch_export_patches
917

1018

1119
class TestShapeHelper(ExtTestCase):
20+
21+
@requires_transformers("4.52")
22+
@requires_torch("2.7.99")
23+
def test_all_dynamic_shape_from_cache(self):
24+
cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))])
25+
ds = all_dynamic_shape_from_inputs(cache)
26+
self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds)
27+
28+
@requires_torch("2.7.99")
29+
def test_all_dynamic_shape_all_transformers_cache(self):
30+
caches = [
31+
(
32+
make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))]),
33+
[[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]],
34+
),
35+
(
36+
make_encoder_decoder_cache(
37+
make_dynamic_cache(
38+
[
39+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
40+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
41+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
42+
]
43+
),
44+
make_dynamic_cache(
45+
[
46+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
47+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
48+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
49+
]
50+
),
51+
),
52+
[
53+
[
54+
[
55+
{0: "d_0_0", 1: "d_0_1", 2: "d_0_2"},
56+
{0: "d_1_0", 1: "d_1_1", 2: "d_1_2"},
57+
{0: "d_2_0", 1: "d_2_1", 2: "d_2_2"},
58+
],
59+
[
60+
{0: "d_3_0", 1: "d_3_1", 2: "d_3_2"},
61+
{0: "d_4_0", 1: "d_4_1", 2: "d_4_2"},
62+
{0: "d_5_0", 1: "d_5_1", 2: "d_5_2"},
63+
],
64+
],
65+
[
66+
[
67+
{0: "d_6_0", 1: "d_6_1", 2: "d_6_2"},
68+
{0: "d_7_0", 1: "d_7_1", 2: "d_7_2"},
69+
{0: "d_8_0", 1: "d_8_1", 2: "d_8_2"},
70+
],
71+
[
72+
{0: "d_9_0", 1: "d_9_1", 2: "d_9_2"},
73+
{0: "d_10_0", 1: "d_10_1", 2: "d_10_2"},
74+
{0: "d_11_0", 1: "d_11_1", 2: "d_11_2"},
75+
],
76+
],
77+
],
78+
),
79+
(
80+
make_sliding_window_cache(
81+
[
82+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
83+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
84+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
85+
]
86+
),
87+
[
88+
[
89+
{0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"},
90+
{0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"},
91+
{0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"},
92+
],
93+
[
94+
{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"},
95+
{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"},
96+
{0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"},
97+
],
98+
],
99+
),
100+
(
101+
make_static_cache(
102+
[
103+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
104+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
105+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
106+
],
107+
max_cache_len=15,
108+
),
109+
[
110+
[
111+
{0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"},
112+
{0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"},
113+
{0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"},
114+
],
115+
[
116+
{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"},
117+
{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"},
118+
{0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"},
119+
],
120+
],
121+
),
122+
(
123+
make_mamba_cache(
124+
[
125+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
126+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
127+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
128+
]
129+
),
130+
[
131+
[
132+
{0: "d_0_0", 1: "d_0_1", 2: "d_0_2"},
133+
{0: "d_1_0", 1: "d_1_1", 2: "d_1_2"},
134+
{0: "d_2_0", 1: "d_2_1", 2: "d_2_2"},
135+
],
136+
[
137+
{0: "d_3_0", 1: "d_3_1", 2: "d_3_2"},
138+
{0: "d_4_0", 1: "d_4_1", 2: "d_4_2"},
139+
{0: "d_5_0", 1: "d_5_1", 2: "d_5_2"},
140+
],
141+
],
142+
),
143+
]
144+
with torch_export_patches(patch_transformers=True):
145+
for cache, exds in caches:
146+
with self.subTest(cache=type(cache)):
147+
ds = all_dynamic_shape_from_inputs(cache)
148+
self.assertEqual(exds, ds)
149+
12150
@requires_transformers("4.52")
13151
@requires_torch("2.7.99")
14152
def test_all_dynamic_shape_from_inputs(self):

_unittests/ut_reference/test_backend_onnxruntime_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def supports_device(cls, device: str) -> bool:
5252
d = Device(device)
5353
if d == DeviceType.CPU:
5454
return True
55-
if d == DeviceType.GPU:
55+
if d == DeviceType.CUDA:
5656
import torch
5757

5858
return torch.cuda.is_available()
@@ -61,7 +61,7 @@ def supports_device(cls, device: str) -> bool:
6161
@classmethod
6262
def create_inference_session(cls, model, device):
6363
d = Device(device)
64-
if d == DeviceType.GPU:
64+
if d == DeviceType.CUDA:
6565
providers = ["CUDAExecutionProvider"]
6666
elif d == DeviceType.CPU:
6767
providers = ["CPUExecutionProvider"]

onnx_diagnostic/export/shape_helper.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,77 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
3030
)
3131
ds = all_dynamic_shape_from_inputs(inputs)
3232
pprint.pprint(ds)
33+
34+
For this function to work, patches must be enabled if :epkg:`transformers`
35+
does not implement the serialization functions.
36+
37+
.. runpython::
38+
:showcode:
39+
40+
import pprint
41+
import torch
42+
from onnx_diagnostic.helpers.cache_helper import (
43+
make_dynamic_cache,
44+
make_encoder_decoder_cache,
45+
make_mamba_cache,
46+
make_sliding_window_cache,
47+
make_static_cache,
48+
)
49+
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
50+
from onnx_diagnostic.torch_export_patches import torch_export_patches
51+
52+
caches = [
53+
make_dynamic_cache(
54+
[
55+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
56+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
57+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
58+
]
59+
),
60+
make_encoder_decoder_cache(
61+
make_dynamic_cache(
62+
[
63+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
64+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
65+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
66+
]
67+
),
68+
make_dynamic_cache(
69+
[
70+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
71+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
72+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
73+
]
74+
),
75+
),
76+
make_sliding_window_cache(
77+
[
78+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
79+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
80+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
81+
]
82+
),
83+
make_static_cache(
84+
[
85+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
86+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
87+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
88+
],
89+
max_cache_len=15,
90+
),
91+
make_mamba_cache(
92+
[
93+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
94+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
95+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
96+
]
97+
),
98+
]
99+
100+
with torch_export_patches(patch_transformers=True):
101+
for cache in caches:
102+
print(f"-- {cache.__class__.__name__}")
103+
pprint.pprint(all_dynamic_shape_from_inputs(cache))
33104
"""
34105
if isinstance(dim_prefix, str):
35106
prefixes: Set[str] = set()

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,21 @@ def flatten_unflatten_for_dynamic_shapes(
3939
subtrees.append(value)
4040
start = end
4141
if use_dict:
42-
if spec.type is dict or spec.context:
42+
if spec.type is dict:
4343
# This a dictionary.
4444
return dict(zip(spec.context, subtrees))
4545
if spec.type is tuple:
4646
return tuple(subtrees)
47+
if spec.type is list:
48+
return list(subtrees)
49+
if spec.context:
50+
# This is a custom class with attributes.
51+
# It is returned as a list.
52+
return list(subtrees)
53+
raise ValueError(
54+
f"Unable to interpret spec type {spec.type} "
55+
f"(type is {type(spec.type)}, context is {spec.context})."
56+
)
4757
# This is a list.
4858
return subtrees
4959

0 commit comments

Comments
 (0)