
Commit 6e6c6ae

json ambiguities (#154)
* json ambiguities
* fix doc
* doc
* mypy
* mypy
1 parent d2a78c4 commit 6e6c6ae


6 files changed, +278 -29 lines


_doc/conf.py

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ def linkcode_resolve(domain, info):
     "onnx-script": "https://github.com/microsoft/onnxscript",
     "onnxscript": "https://github.com/microsoft/onnxscript",
     "onnxscript Tutorial": "https://microsoft.github.io/onnxscript/tutorial/index.html",
+    "optree": "https://github.com/metaopt/optree",
     "Pattern-based Rewrite Using Rules With onnxscript": "https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html",
     "opsets": "https://onnx.ai/onnx/intro/concepts.html#what-is-an-opset-version",
     "pyinstrument": "https://pyinstrument.readthedocs.io/en/latest/",

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+"""
+JSON returns list when the original dynamic shapes are list or tuple
+====================================================================
+
+Dynamic shapes given to :func:`torch.export.export` must follow the
+same semantics as the inputs. What if we confuse tuple and list when defining the dynamic shapes?
+How do we restore the expected type, assuming we know the inputs?
+This is not often useful, but maybe we will learn more about
+:epkg:`optree`.
+
+Dynamic Shapes After JSON
++++++++++++++++++++++++++
+
+The JSON format does not distinguish between a list and a tuple,
+so after serializing to JSON and restoring, both of them become lists.
+"""
+
+import json
+import pprint
+import torch
+from onnx_diagnostic import doc
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+
+bsize, nheads, slen, dim = 2, 1, 30, 96
+
+inputs = dict(
+    input_mask_position=(
+        torch.randint(15, size=(2, 3), dtype=torch.int64),
+        torch.randint(1, size=(2, 33), dtype=torch.int64),
+        torch.arange(3, dtype=torch.int64),
+    ),
+    past_key_values=make_dynamic_cache(
+        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
+    ),
+)
+
+print(string_type(inputs, with_shape=True))
+
+# %%
+# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
+# produces the corresponding dynamic shapes assuming they are all dynamic.
+ds = all_dynamic_shape_from_inputs(inputs)
+pprint.pprint(ds)
+
+# %%
+# Converted into JSON.
+
+json_str = json.dumps(ds, indent=2, ensure_ascii=False)
+print(json_str)
+
+# %%
+# Restoration.
+ds2 = json.loads(json_str)
+pprint.pprint(ds2)
+
+# %%
+# Tuples are replaced by lists.
+
+# The trick to restore tuple when expected
+# ++++++++++++++++++++++++++++++++++++++++
+
+
+def flatten_unflatten_like_dynamic_shapes(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    flat, spec = torch.utils._pytree.tree_flatten(obj)
+    start = 0
+    end = 0
+    subtrees = []
+    for subspec in spec.children_specs:
+        end += subspec.num_leaves
+        value = subspec.unflatten(flat[start:end])
+        value = flatten_unflatten_like_dynamic_shapes(value)
+        subtrees.append(value)
+        start = end
+    if spec.type is dict or spec.context:
+        return dict(zip(spec.context, subtrees))
+    if spec.type is tuple:
+        return tuple(subtrees)
+    return subtrees
+
+
+def _align(inputs, ds):
+    if isinstance(inputs, torch.Tensor):
+        return ds
+    if isinstance(inputs, tuple):
+        return tuple(_align(o, d) for o, d in zip(inputs, ds))
+    if isinstance(inputs, list):
+        return [_align(o, d) for o, d in zip(inputs, ds)]
+    if isinstance(inputs, dict):
+        return {k: _align(inputs[k], d) for k, d in ds.items()}
+    raise TypeError(f"Unexpected types inputs is {type(inputs)}, ds is {type(ds)}")
+
+
+def fix_dynamic_shapes(inputs, dynamic_shapes):
+    flat_unflat_inputs = flatten_unflatten_like_dynamic_shapes(inputs)
+    return _align(flat_unflat_inputs, dynamic_shapes)
+
+
+fixed_ds = fix_dynamic_shapes(inputs, ds2)
+pprint.pprint(fixed_ds)
+
+# %%
+# The code changed the list back into a tuple as expected.
+assert isinstance(ds2["input_mask_position"], list)
+assert isinstance(fixed_ds["input_mask_position"], tuple)
+
+
+# %%
+
+doc.plot_legend("dynamic shapes\nto json\nfrom json", "torch.export.export", "green")

_unittests/ut_export/test_shape_helper.py

Lines changed: 31 additions & 4 deletions
@@ -1,7 +1,10 @@
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
-from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+from onnx_diagnostic.export.shape_helper import (
+    all_dynamic_shape_from_inputs,
+    guess_dynamic_shapes_from_inputs,
+)
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs


@@ -10,23 +13,24 @@ class TestShapeHelper(ExtTestCase):
     @requires_torch("2.7.99")
     def test_all_dynamic_shape_from_inputs(self):
         ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
+        self.assertEqual(({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}), ds)
+        ds = all_dynamic_shape_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
         self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
         ds = all_dynamic_shape_from_inputs(
             (torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
         )
         self.assertEqual(
-            [
+            (
                 {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
                 {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
-            ],
+            ),
             ds,
         )

     @requires_transformers("4.52")
     @requires_torch("2.7.99")
     def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
         data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
-        print(self.string_type(data["inputs"], with_shape=True))
         ds = all_dynamic_shape_from_inputs(data["inputs"])
         self.assertEqual(
             {
@@ -41,6 +45,29 @@ def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
             ds,
         )

+    @requires_transformers("4.52")
+    @requires_torch("2.7.99")
+    def test_guess_dynamic_shapes_from_inputs(self):
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
+        guessed = guess_dynamic_shapes_from_inputs(
+            [data["inputs"], data["inputs2"]], auto="dd"
+        )
+        self.assertEqual(
+            (
+                (),
+                {
+                    "attention_mask": {0: "dd_0I0", 1: "dd_0I1"},
+                    "input_ids": {0: "dd_1I0", 1: "dd_1I1"},
+                    "past_key_values": [
+                        [{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}],
+                        [{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}],
+                    ],
+                    "position_ids": {0: "dd_3I0", 1: "dd_3I1"},
+                },
+            ),
+            guessed,
+        )
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 48 additions & 20 deletions
@@ -630,9 +630,12 @@ def __init__(
         method_name: str = "forward",
         name: str = "main",
     ):
-        assert isinstance(model, torch.nn.Module) or inspect.ismodule(
-            model
-        ), f"unexpected type for model={type(model)}, it must be a torch.nn.Module"
+        assert (
+            model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
+        ), (
+            f"unexpected type for model={type(model)}, "
+            f"it must be a torch.nn.Module or None"
+        )
         assert name, (
             f"name={name!r} cannot be empty this string is used to "
             f"display meaningful error messages"
@@ -641,26 +644,42 @@ def __init__(
         self.model = model
         self.level = level
         self.method_name = method_name
-        self.forward = getattr(model, method_name)
-        self.signature = inspect.signature(self.forward)
+        self.forward = getattr(model, method_name) if model is not None else None
+        self.signature = inspect.signature(self.forward) if self.forward else None

         # information about the signature
-        self.forward_parameter_names = set(
-            p.name
-            for p in self.signature.parameters.values()
-            if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
+        self.forward_parameter_names = (
+            set(
+                p.name
+                for p in self.signature.parameters.values()
+                if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
+            )
+            if self.signature
+            else None
+        )
+        self.forward_ordered_parameter_names = (
+            list(self.signature.parameters) if self.signature else None
+        )
+        self.forward_positioned_parameter_names = (
+            [
+                p.name
+                for p in self.signature.parameters.values()
+                if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+            ]
+            if self.signature
+            else None
+        )
+        names = (
+            [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
+            if self.signature
+            else None
         )
-        self.forward_ordered_parameter_names = list(self.signature.parameters)
-        self.forward_positioned_parameter_names = [
-            p.name
-            for p in self.signature.parameters.values()
-            if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
-        ]
-        names = [
-            p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
-        ]
         self.forward_args = names[0] if names else None
-        names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
+        names = (
+            [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
+            if self.signature
+            else None
+        )
         self.forward_kwargs = names[0] if names else None
         self.forward_custom_op_schema = None
         self.forward_need_serialization = False
@@ -711,6 +730,7 @@ def process_inputs(
     @property
     def true_model_name(self) -> str:
         "Returns class name or module name."
+        assert self.model is not None, "model was None when the class was initialized."
         return (
             self.model.__class__.__name__
             if isinstance(self.model, torch.nn.Module)
@@ -942,7 +962,7 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
                 )
             )
         names = s2.pop()
-        for name in names:
+        for i, name in enumerate(names):
             assert name not in {"_diag", "verbose"}, (
                 f"{self.full_name}: unexpected parameter {name!r}, names={names}"
                 f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
@@ -968,6 +988,14 @@ def move_to_kwargs(
         with the corresponding dynamic shapes.
         *kwargs*, *dynamic_shapes* are modified inplace.
         """
+        assert (
+            self.signature is not None
+            and self.forward_parameter_names is not None
+            and self.forward_ordered_parameter_names is not None
+        ), (
+            "model was None when the class was initialized, "
+            "cannot move args to kwargs without the signature."
+        )
         sig = self.signature
         arg_dyn, kw_dyn = dynamic_shapes
         for i, p in enumerate(sig.parameters):
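
Taken together, these changes let ModelInputs be built without a model, which is what the new helper in shape_helper.py relies on. A rough sketch of that model-less path with made-up inputs (the exact dimension names depend on the auto prefix):

import torch
from onnx_diagnostic.export.dynamic_shapes import ModelInputs

# two sets of keyword inputs where only the first dimension of "x" changes
input_sets = [dict(x=torch.randn(2, 8)), dict(x=torch.randn(3, 8))]
mi = ModelInputs(None, input_sets)  # no torch.nn.Module required anymore
args_ds, kwargs_ds = mi.guess_dynamic_shapes(auto="d")
print(kwargs_ds)  # expected: dimension 0 of "x" reported as dynamic, dimension 1 left static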

onnx_diagnostic/export/shape_helper.py

Lines changed: 78 additions & 1 deletion
@@ -1,5 +1,6 @@
-from typing import Any, Set
+from typing import Any, Dict, List, Set, Tuple, Union
 from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
+from .dynamic_shapes import ModelInputs


 def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
@@ -47,3 +48,79 @@ def tensor_to_shape(tensor):
     return flatten_unflatten_for_dynamic_shapes(
         inputs, change_function=tensor_to_shape, use_dict=True
     )
+
+
+def guess_dynamic_shapes_from_inputs(
+    inputs: List[Any], auto: Union[bool, str] = False
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Guesses which dimension is dynamic from a set of inputs.
+    Every dimension having different values over multiple sets
+    of inputs becomes dynamic. Every dimension not changing remains static.
+
+    :param inputs: a list of input sets
+    :param auto: True for ``torch.export.Dim.AUTO``,
+        False for ``torch.export.Dim.DYNAMIC``,
+        a string to get a unique string for every dynamic dimension
+    :return: args and kwargs
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
+
+        bsize, nheads, slen, dim = 2, 1, 30, 96
+        inputs1 = dict(
+            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
+            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+            position_ids=torch.arange(3, dtype=torch.int64),
+            past_key_values=make_dynamic_cache(
+                [
+                    (
+                        torch.randn(bsize, nheads, slen, dim),
+                        torch.randn(bsize, nheads, slen, dim),
+                    ),
+                ]
+            ),
+        )
+        bsize, nheads, slen, dim = 3, 1, 33, 96
+        inputs2 = dict(
+            input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
+            attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
+            position_ids=torch.arange(4, dtype=torch.int64),
+            past_key_values=make_dynamic_cache(
+                [
+                    (
+                        torch.randn(bsize, nheads, slen, dim),
+                        torch.randn(bsize, nheads, slen, dim),
+                    ),
+                ]
+            ),
+        )
+        ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
+        pprint.pprint(ds)
+
+    This function returns something equivalent to what
+    :class:`torch.export.dynamic_shapes.AdditionalInputs` produces,
+    but that class needs a model.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
+        ds = torch.export.dynamic_shapes.AdditionalInputs()
+        ds.add((), data["inputs"])
+        ds.add((), data["inputs2"])
+        pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
+    """
+    mi = ModelInputs(None, inputs)
+    return mi.guess_dynamic_shapes(auto=auto)
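
When auto is a string, the guessed dimensions are plain strings, so the result of this new helper can go straight through the JSON round trip shown in the example above. A small hedged sketch with invented inputs:

import json
import torch
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

inputs1 = dict(input_ids=torch.randint(10, size=(2, 3), dtype=torch.int64))
inputs2 = dict(input_ids=torch.randint(10, size=(3, 7), dtype=torch.int64))
args_ds, kwargs_ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
print(json.dumps(kwargs_ds, indent=2))  # both dimensions changed, so both are reported dynamic

Note that the integer axis indices come back as string keys after json.loads, another ambiguity the JSON round trip introduces on top of the tuple/list one.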

0 commit comments
