Commit cae7a79

Add function to produce full dynamic shapes (#152)
* Add function to produce full dynamic shapes
* mypy
* doc
* changes
* fix
* fix
1 parent: f3f167a · commit: cae7a79

File tree

9 files changed: +201 -6 lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,11 @@
 Change Logs
 ===========

+0.7.1
++++++
+
+* :pr:`152`: add a function to compute fully dynamic shapes given any inputs
+
 0.7.0
 +++++

_doc/api/export/index.rst

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ onnx_diagnostic.export
     :caption: modules

     dynamic_shapes
+    shape_helper
     validate

 CoupleInputsDynamicShapes

_doc/api/export/shape_helper.rst

Lines changed: 7 additions & 0 deletions (new file)

onnx_diagnostic.export.shape_helper
===================================

.. automodule:: onnx_diagnostic.export.shape_helper
    :members:
    :no-undoc-members:

_doc/index.rst

Lines changed: 2 additions & 1 deletion

@@ -211,8 +211,9 @@ The function replaces dynamic dimensions defined as strings by
 ``torch.export.Dim.DYNAMIC``.

 Older versions
-++++++++++++++
+==============

+* `0.7.1 <../v0.7.1/index.html>`_
 * `0.7.0 <../v0.7.0/index.html>`_
 * `0.6.3 <../v0.6.3/index.html>`_
 * `0.5.0 <../v0.5.0/index.html>`_

Lines changed: 78 additions & 0 deletions (new example script)

"""
Builds dynamic shapes from any input
====================================

Getting dynamic shapes right for :func:`torch.export.export` is tricky when
the inputs include a custom class such as a
:class:`transformers.cache_utils.DynamicCache`:
:func:`torch.export.export` cannot take the dynamic shapes as a DynamicCache,
it expects them in a flattened-then-unflattened form of the cache instead.

Standard inputs for a LLM with a dynamic cache
++++++++++++++++++++++++++++++++++++++++++++++
"""

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches

bsize, nheads, slen, dim = 2, 1, 30, 96

inputs = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
    position_ids=torch.arange(3, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)

print(string_type(inputs, with_shape=True))

# %%
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
# produces the corresponding dynamic shapes assuming they are all dynamic.

ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)

# %%
# What about a StaticCache?
# +++++++++++++++++++++++++
#
# We use :func:`onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs`
# to get a consistent configuration with a static cache.

data = get_untrained_model_with_inputs(
    "arnir0/Tiny-LLM",
    model_kwargs=dict(cache_implementation="static"),
    inputs_kwargs=dict(cls_cache="StaticCache"),
)
inputs = data["inputs"]
print(string_type(inputs, with_shape=True))

# %%
# And the input shapes.
ds = all_dynamic_shape_from_inputs(inputs)
if ds["past_key_values"]:
    print("transformers implemented a serialization function for StaticCache.")
else:
    print("We need the serialization functions implemented in this package.")
    with torch_export_patches(patch_transformers=True):
        ds = all_dynamic_shape_from_inputs(inputs)

# %%
# That gives:
pprint.pprint(ds)

# %%
# We can compare with the dynamic shapes returned by
# ``get_untrained_model_with_inputs``.
pprint.pprint(data["dynamic_shapes"])


# %%

doc.plot_legend("dynamic shapes\nfrom inputs", "dynamic shapes", "green")

Lines changed: 46 additions & 0 deletions (new test file)

import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs


class TestShapeHelper(ExtTestCase):
    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    def test_all_dynamic_shape_from_inputs(self):
        ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
        self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
        ds = all_dynamic_shape_from_inputs(
            (torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
        )
        self.assertEqual(
            [
                {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
                {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
            ],
            ds,
        )

    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
        print(self.string_type(data["inputs"], with_shape=True))
        ds = all_dynamic_shape_from_inputs(data["inputs"])
        self.assertEqual(
            {
                "input_ids": {0: "d_0_0", 1: "d_0_1"},
                "attention_mask": {0: "d_1_0", 1: "d_1_1"},
                "position_ids": {0: "d_2_0", 1: "d_2_1"},
                "past_key_values": {
                    "key_cache": [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
                    "value_cache": [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
                },
            },
            ds,
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -3,5 +3,5 @@
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.7.0"
+__version__ = "0.7.1"
 __author__ = "Xavier Dupré"

onnx_diagnostic/export/shape_helper.py

Lines changed: 49 additions & 0 deletions (new file)

from typing import Any, Set
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes


def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
    """
    Returns the dynamic shapes for the given inputs.
    All dimensions are considered dynamic.
    ``dim_prefix`` can be a string (the function uses it as a prefix),
    or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [(torch.randn(bsize, nheads, slen, dim),
                  torch.randn(bsize, nheads, slen, dim))]
            ),
        )
        ds = all_dynamic_shape_from_inputs(inputs)
        pprint.pprint(ds)
    """
    if isinstance(dim_prefix, str):
        prefixes: Set[str] = set()

        def tensor_to_shape(tensor):
            # One fresh prefix per tensor (d_0, d_1, ...), one name per dimension.
            n = len(prefixes)
            p = f"{dim_prefix}_{n}"
            prefixes.add(p)
            return {i: f"{p}_{i}" for i in range(tensor.ndim)}

    else:

        def tensor_to_shape(tensor):
            # The same Dim (AUTO or DYNAMIC) is used for every dimension.
            return {i: dim_prefix for i in range(tensor.ndim)}  # noqa: C420

    return flatten_unflatten_for_dynamic_shapes(
        inputs, change_function=tensor_to_shape, use_dict=True
    )
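
As a quick illustration of the naming scheme (the expected values below are
taken from the unit test earlier in this commit): every tensor gets its own
prefix ``d_0``, ``d_1``, ... and each dimension index is appended to it.

import torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

# Two tensors -> prefixes d_0 and d_1, one entry per dimension.
ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
print(ds)  # [{0: 'd_0_0', 1: 'd_0_1'}, {0: 'd_1_0', 1: 'd_1_1'}]

# With a Dim instead of a string prefix, every dimension gets that Dim.
ds = all_dynamic_shape_from_inputs(
    (torch.randn((5, 6)),), dim_prefix=torch.export.Dim.DYNAMIC
)
print(ds)  # [{0: Dim.DYNAMIC, 1: Dim.DYNAMIC}]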

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 12 additions & 4 deletions

@@ -1,11 +1,15 @@
-from typing import Any, List, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 import packaging.version as pv
 import torch
 import transformers
 import transformers.cache_utils


-def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
+def flatten_unflatten_for_dynamic_shapes(
+    obj: Any,
+    use_dict: bool = False,
+    change_function: Optional[Callable[[torch.Tensor], Any]] = None,
+) -> Any:
     """
     Returns the object in a different structure similar to what
     the definition of the dynamic shapes should use.
@@ -16,18 +20,22 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
     the context gives the dictionary keys but it is not expressed
     in the dynamic shapes, these specifications seem to be different
     for the strict and non-strict mode.
+    :param change_function: modifies the tensors in the structure itself,
+        for example replaces them by their shape
     :return: the serialized object
     """
     if isinstance(obj, torch.Tensor):
-        return obj
+        return change_function(obj) if change_function else obj
     flat, spec = torch.utils._pytree.tree_flatten(obj)
     start = 0
     end = 0
     subtrees = []
     for subspec in spec.children_specs:
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
-        value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
+        value = flatten_unflatten_for_dynamic_shapes(
+            value, use_dict=use_dict, change_function=change_function
+        )
         subtrees.append(value)
         start = end
     if use_dict and (spec.type is dict or spec.context):
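
A minimal sketch of the new ``change_function`` parameter, assuming the
DynamicCache flattening shown in the tests above (``key_cache`` /
``value_cache`` lists):

import torch
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)

cache = make_dynamic_cache([(torch.randn(2, 1, 30, 96), torch.randn(2, 1, 30, 96))])

# change_function is applied to every tensor leaf; here it swaps each
# tensor for its shape, which is what all_dynamic_shape_from_inputs does
# with dimension names instead of sizes.
shapes = flatten_unflatten_for_dynamic_shapes(
    cache, use_dict=True, change_function=lambda t: tuple(t.shape)
)
print(shapes)
# expected: {'key_cache': [(2, 1, 30, 96)], 'value_cache': [(2, 1, 30, 96)]}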
