Commit edec507

Using string instead of Dim (#58)

* Using string instead of Dim
* doc
* improvement
* improves dynamic shapes handling
* changes
* fix things
* f
* disabl

1 parent 1eab135 commit edec507

19 files changed: +488 additions, −224 deletions

CHANGELOGS.rst

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@ Change Logs
 0.4.0
 +++++
 
+* :pr:`58`: add function ``use_dyn_not_str`` to replace strings with ``torch.export.Dim.DYNAMIC``;
+  the dynamic shapes returned for a specific model now use strings instead of
+  ``torch.export.Dim.DYNAMIC``, which is a valid definition for ``torch.onnx.export``
+  and lets it reuse the names
 * :pr:`55`: add support for text-classification
 * :pr:`54`: add support for fill-mask, refactoring
 * :pr:`52`: add support for zero-shot-image-classification
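
As a quick illustration of the new helper, based on the unit test added in this commit
(the keys and dimension names are examples only):

    import torch
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    batch = torch.export.Dim("batch")
    # Strings such as "seq" are valid dimension names for torch.onnx.export,
    # which can reuse them in the exported model.
    dynamic_shapes = dict(
        input_ids={0: batch, 1: "seq"},
        attention_mask={0: batch, 1: "seq"},
    )

    # torch.export.export expects Dim objects, so the strings are swapped for
    # torch.export.Dim.DYNAMIC; existing Dim objects pass through unchanged.
    DYN = torch.export.Dim.DYNAMIC
    assert use_dyn_not_str(dynamic_shapes) == dict(
        input_ids={0: batch, 1: DYN},
        attention_mask={0: batch, 1: DYN},
    )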

_doc/examples/plot_export_hub_codellama.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@
     task_from_id,
 )
 from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 model_id = "codellama/CodeLlama-7b-Python-hf"
 print("info", get_model_info(model_id))
@@ -96,7 +97,7 @@
     model,
     (),
     kwargs=f(data["inputs"]),
-    dynamic_shapes=data["dynamic_shapes"],
+    dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
     strict=False,
 )
 print(ep)

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
 from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
 from onnx_diagnostic.helpers.rt_helper import make_feeds
 from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 from onnx_diagnostic.torch_models.hghub import (
     get_untrained_model_with_inputs,
 )
@@ -92,7 +93,7 @@
     untrained_model,
     (),
     kwargs=modificator(copy.deepcopy(inputs)),
-    dynamic_shapes=dynamic_shapes,
+    dynamic_shapes=use_dyn_not_str(dynamic_shapes),
     strict=False,  # mandatory for torch==2.6
 )

_doc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ onnx-diagnostic: investigate onnx models
 
 The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
 it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
+Sources available at `github/onnx-diagnostic <https://github.com/sdpython/onnx-diagnostic/>`_.
 
 .. code-block:: python
 

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 8 additions & 7 deletions
@@ -576,17 +576,20 @@ def test_couple_input_ds_cache(self):
             Cls(
                 (),
                 kwargs,
-                {"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
+                {
+                    "A": ds_batch,
+                    "B": (ds_batch, [[ds_batch, ds_batch], [ds_batch, ds_batch]]),
+                },
             ).invalid_dimensions_for_export(),
         )
         self.assertEqual(
-            {"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])},
+            {"B": (None, [[None, {2: "d=[1]"}], [None, {2: "d=[1]"}]])},
             Cls(
                 (),
                 kwargs,
                 {
                     "A": ds_batch,
-                    "B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
+                    "B": (ds_batch, [[ds_batch, ds_batch_seq], [ds_batch, ds_batch_seq]]),
                 },
             ).invalid_dimensions_for_export(),
         )
@@ -762,10 +765,8 @@ def test_dynamic_cache_replace_by_string(self):
         self.assertEqual(
             {
                 "cache": [
-                    {0: "Dim0", 1: "Dim1"},
-                    {0: "Dim2", 1: "Dim3"},
-                    {0: "Dim4", 1: "Dim5"},
-                    {0: "Dim6", 1: "Dim7"},
+                    [{0: "Dim0", 1: "Dim1"}, {0: "Dim2", 1: "Dim3"}],
+                    [{0: "Dim4", 1: "Dim5"}, {0: "Dim6", 1: "Dim7"}],
                 ]
             },
             as_string,
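
These expectations changed because a flattened DynamicCache is now represented as two
nested lists, the per-layer keys and the per-layer values, rather than one flat list of
tensors. A minimal sketch of the two layouts (the shape specs are illustrative only):

    import torch

    batch = torch.export.Dim("batch")
    kv = {0: batch, 2: "seq"}  # shape spec shared by every key/value tensor here

    # Old layout for a two-layer cache: one flat list of four specs.
    old_cache_ds = [kv, kv, kv, kv]

    # New layout: [[key_layer0, key_layer1], [value_layer0, value_layer1]],
    # matching the "cache" expectation in the test above.
    new_cache_ds = [[kv, kv], [kv, kv]]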
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.export import CoupleInputsDynamicShapes
+from onnx_diagnostic.torch_export_patches.patch_inputs import (
+    convert_dynamic_axes_into_dynamic_shapes,
+)
+from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+
+
+class TestCacheHelpers(ExtTestCase):
+    def test_string_type(self):
+        DYN = torch.export.Dim.DYNAMIC
+        self.assertEqual("DYNAMIC", string_type(DYN, verbose=0))
+        AUTO = torch.export.Dim.AUTO
+        self.assertEqual("AUTO", string_type(AUTO, verbose=0))
+        self.assertEqual("#1[DYNAMIC]", string_type([DYN]))
+
+        batch = torch.export.Dim("batch")
+        dynamic_shapes = dict(
+            input_ids={0: batch, 1: "seq"},
+            attention_mask={0: batch, 1: "seq"},
+            position_ids={0: batch, 1: "seq"},
+            past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
+        )
+        self.assertEqual(
+            "dict(input_ids:{0:Dim(batch),1:DYN(seq)},"
+            "attention_mask:{0:Dim(batch),1:DYN(seq)},"
+            "position_ids:{0:Dim(batch),1:DYN(seq)},"
+            "past_key_values:#2[#1[{0:Dim(batch),2:DYN(seq)}],"
+            "#1[{0:Dim(batch),2:DYN(seq)}]])",
+            string_type(dynamic_shapes),
+        )
+
+    def test_replace_by(self):
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_dynamic_cache(
+            [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
+        )
+        kwargs = dict(
+            input_ids=torch.zeros(2, 3),
+            attention_mask=torch.zeros(2, 3),
+            position_ids=torch.zeros(2, 3),
+            past_key_values=past_key_values,
+        )
+        batch = torch.export.Dim("batch")
+        dynamic_shapes = dict(
+            input_ids={0: batch, 1: "seq"},
+            attention_mask={0: batch, 1: "seq"},
+            position_ids={0: batch, 1: "seq"},
+            past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
+        )
+
+        DYN = torch.export.Dim.DYNAMIC
+        nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
+            None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes
+        )
+        self.assertEqual(dynamic_shapes, nds)
+
+        with bypass_export_some_errors(patch_transformers=True):
+            cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes)
+            res = cpl.replace_string_by()
+            dsc = res["past_key_values"]
+            self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_inputs.py

Lines changed: 21 additions & 0 deletions
@@ -5,6 +5,7 @@
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.torch_export_patches.patch_inputs import (
     convert_dynamic_axes_into_dynamic_shapes,
+    use_dyn_not_str,
 )
 
 
@@ -111,6 +112,26 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
             string_type(res[1], with_shape=True),
         )
 
+    def test_use_dyn_not_str(self):
+        batch = torch.export.Dim("batch")
+        dynamic_shapes = dict(
+            input_ids={0: batch, 1: "seq"},
+            attention_mask={0: batch, 1: "seq"},
+            position_ids={0: batch, 1: "seq"},
+            past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
+        )
+        res = use_dyn_not_str(dynamic_shapes)
+        DYN = torch.export.Dim.DYNAMIC
+        self.assertEqual(
+            dict(
+                input_ids={0: batch, 1: DYN},
+                attention_mask={0: batch, 1: DYN},
+                position_ids={0: batch, 1: DYN},
+                past_key_values=[[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]],
+            ),
+            res,
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 7 additions & 1 deletion
@@ -2,7 +2,12 @@
 import unittest
 import packaging.version as pv
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    ignore_warnings,
+    requires_torch,
+)
 from onnx_diagnostic.torch_models.test_helper import (
     get_inputs_for_task,
     validate_model,
@@ -54,6 +59,7 @@ def test_validate_model_export(self):
         self.assertIsInstance(summary, dict)
         self.assertIsInstance(data, dict)
 
+    @requires_torch("2.7")
     @hide_stdout()
     @ignore_warnings(FutureWarning)
     def test_validate_model_onnx(self):

onnx_diagnostic/_command_lines_parser.py

Lines changed: 2 additions & 2 deletions
@@ -302,7 +302,7 @@ def get_parser_validate() -> ArgumentParser:
 
 def _cmd_validate(argv: List[Any]):
     from .helpers import string_type
-    from .torch_models.test_helper import get_inputs_for_task, validate_model, _ds_clean
+    from .torch_models.test_helper import get_inputs_for_task, validate_model
     from .tasks import supported_tasks
 
     parser = get_parser_validate()
@@ -320,7 +320,7 @@ def _cmd_validate(argv: List[Any]):
             print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
         print("-- dynamic_shapes")
         for k, v in data["dynamic_shapes"].items():
-            print(f" + {k.ljust(max_length)}: {_ds_clean(v)}")
+            print(f" + {k.ljust(max_length)}: {string_type(v)}")
     else:
         # Let's skip any invalid combination if known to be unsupported
         if (
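
With _ds_clean gone, the CLI renders dynamic shapes through string_type. Based on the
assertions in the new TestCacheHelpers.test_string_type above, the rendering looks roughly
like this (a sketch of the expected strings, not verified CLI output):

    import torch
    from onnx_diagnostic.helpers import string_type

    DYN = torch.export.Dim.DYNAMIC
    print(string_type(DYN, verbose=0))  # DYNAMIC
    print(string_type([DYN]))           # #1[DYNAMIC]
    batch = torch.export.Dim("batch")
    # A per-tensor spec renders roughly as {0:Dim(batch),1:DYN(seq)}.
    print(string_type({0: batch, 1: "seq"}))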

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 51 additions & 13 deletions
@@ -92,7 +92,8 @@ def replace_string_by(self, value: Any = None):
         return self._generic_walker(
             lambda inputs, ds, value=value: self._replace_string_dim_tensor(
                 inputs, ds, value=value
-            )
+            ),
+            flatten_unflatten=True,
         )
 
     @classmethod
@@ -135,7 +136,8 @@ def replace_by_string(self):
         return self._generic_walker(
             lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
                 inputs, ds, unique=unique
-            )
+            ),
+            flatten_unflatten=True,
         )
 
     @classmethod
@@ -203,7 +205,7 @@ def invalid_dimensions_for_export(self):
         ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
         print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
         """
-        return self._generic_walker(self._valid_shapes_tensor)
+        return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)
 
     @classmethod
     def _valid_shapes_tensor(cls, inputs, ds):
@@ -221,7 +223,9 @@ def _valid_shapes_tensor(cls, inputs, ds):
                 issues[i] = f"d=[{d}]"
         return issues if issues else None
 
-    def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
+    def _generic_walker(
+        self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
+    ):
         """
         Generic deserializator walking through inputs and dynamic_shapes all along.
         The function returns a result with the same structure as the dynamic shapes.
@@ -231,15 +235,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
                 f"Type mismatch, args={string_type(self.args)} and "
                 f"dynamic_shapes={self.dynamic_shapes} should have the same type."
             )
-            res = self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
+            res = self._generic_walker_step(
+                processor,
+                self.kwargs,
+                self.dynamic_shapes,
+                flatten_unflatten=flatten_unflatten,
+            )
             return (tuple(), res) if args_kwargs else res
 
         if not self.kwargs:
             assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
                 f"Type mismatch, args={string_type(self.args)} and "
                 f"dynamic_shapes={self.dynamic_shapes} should have the same type."
             )
-            res = self._generic_walker_step(processor, self.args, self.dynamic_shapes)
+            res = self._generic_walker_step(
+                processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
+            )
             return (res, {}) if args_kwargs else res
 
         assert isinstance(self.dynamic_shapes, dict), (
@@ -250,12 +261,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
             self.dynamic_shapes
         ):
             # No dynamic shapes for the positional arguments.
-            return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
+            return self._generic_walker_step(
+                processor,
+                self.kwargs,
+                self.dynamic_shapes,
+                flatten_unflatten=flatten_unflatten,
+            )
 
         if isinstance(self.args_names, list):
             if not set(self.args_names) & set(self.dynamic_shapes):
                 # No dynamic shapes for the positional arguments.
-                return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
+                return self._generic_walker_step(
+                    processor,
+                    self.kwargs,
+                    self.dynamic_shapes,
+                    flatten_unflatten=flatten_unflatten,
+                )
 
         assert self.args_names, (
             "args and kwargs are filled, then args_names must be specified in "
@@ -268,7 +289,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
         )
         kwargs = dict(zip(self.args_names, self.args))
         kwargs.update(self.kwargs)
-        res = self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
+        res = self._generic_walker_step(
+            processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
+        )
         if args_kwargs:
             pgs = [None for _ in range(len(self.args))]
             kws = {}
@@ -286,7 +309,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
         )
 
     @classmethod
-    def _generic_walker_step(cls, processor: Callable, inputs, ds):
+    def _generic_walker_step(
+        cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
+    ):
         if isinstance(inputs, torch.Tensor):
             return processor(inputs, ds)
         if isinstance(inputs, (int, float, str)):
@@ -303,7 +328,11 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
         if isinstance(inputs, (tuple, list)):
             value = []
             for i, d in zip(inputs, ds):
-                value.append(cls._generic_walker_step(processor, i, d))
+                value.append(
+                    cls._generic_walker_step(
+                        processor, i, d, flatten_unflatten=flatten_unflatten
+                    )
+                )
             return (
                 (value if isinstance(ds, list) else tuple(value))
                 if any(v is not None for v in value)
@@ -314,7 +343,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
         ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
         dvalue = {}
         for k, v in inputs.items():
-            t = cls._generic_walker_step(processor, v, ds[k])
+            t = cls._generic_walker_step(
+                processor, v, ds[k], flatten_unflatten=flatten_unflatten
+            )
             if t is not None:
                 dvalue[k] = t
         return dvalue if dvalue else None
@@ -325,11 +356,18 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
             f"torch.utils._pytree.register_pytree_node, it is not possible to "
             f"map this class with the given dynamic shapes."
         )
+        if flatten_unflatten:
+            flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
+            return cls._generic_walker_step(
+                processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
+            )
         flat, _spec = torch.utils._pytree.tree_flatten(inputs)
         if all(isinstance(t, torch.Tensor) for t in flat):
            # We need to flatten dynamic shapes as well
            ds = flatten_dynamic_shapes(ds)
-        return cls._generic_walker_step(processor, flat, ds)
+        return cls._generic_walker_step(
+            processor, flat, ds, flatten_unflatten=flatten_unflatten
+        )
 
    class ChangeDimensionProcessor:
        def __init__(self, desired_values):
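
Why the new flatten_unflatten flag: a pytree-registered class such as DynamicCache exposes
no dict/list structure the walker can recurse into. The old path flattened it to a plain
list of tensors, so results lost the cache's nesting; with flatten_unflatten=True the
walker first rebuilds the object as nested lists (via flatten_unflatten_for_dynamic_shapes)
that line up with nested-list dynamic shapes like the [[keys], [values]] layout above.
A rough, simplified sketch of the walking logic, not the library's actual implementation:

    import torch
    from typing import Any, Callable

    def flatten_ds(ds: Any) -> list:
        # Flatten nested lists/tuples of shape specs into a flat list while
        # keeping each {axis: dim} dict intact (what flattening the dynamic
        # shapes has to preserve).
        if isinstance(ds, (list, tuple)):
            return [leaf for d in ds for leaf in flatten_ds(d)]
        return [ds]

    def walk(process: Callable, inputs: Any, ds: Any) -> Any:
        # Recurse through inputs and dynamic shapes in lockstep, applying
        # `process` to every tensor leaf.
        if isinstance(inputs, torch.Tensor):
            return process(inputs, ds)
        if isinstance(inputs, (list, tuple)):
            return [walk(process, i, d) for i, d in zip(inputs, ds)]
        if isinstance(inputs, dict):
            return {k: walk(process, v, ds[k]) for k, v in inputs.items()}
        # Fallback for pytree-registered classes: flatten to plain tensors and
        # pair them with the flattened shape specs. This mirrors the *old*
        # behavior; the commit adds a branch that keeps the nesting instead.
        flat, _spec = torch.utils._pytree.tree_flatten(inputs)
        return [walk(process, t, d) for t, d in zip(flat, flatten_ds(ds))]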
