Skip to content

Commit bbba6df

Browse files
committed
doc
1 parent 6b2b993 commit bbba6df

File tree

3 files changed

+144
-22
lines changed

3 files changed

+144
-22
lines changed

_unittests/ut_export/test_shape_helper.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import unittest
22
import torch
33
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
4-
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
4+
from onnx_diagnostic.export.shape_helper import (
5+
all_dynamic_shape_from_inputs,
6+
guess_dynamic_shapes_from_inputs,
7+
)
58
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
69

710

@@ -42,6 +45,29 @@ def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
4245
ds,
4346
)
4447

48+
@requires_transformers("4.52")
49+
@requires_torch("2.7.99")
50+
def test_guess_dynamic_shapes_from_inputs(self):
51+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
52+
guessed = guess_dynamic_shapes_from_inputs(
53+
[data["inputs"], data["inputs2"]], auto="dd"
54+
)
55+
self.assertEqual(
56+
(
57+
(),
58+
{
59+
"attention_mask": {0: "dd_0I0", 1: "dd_0I1"},
60+
"input_ids": {0: "dd_1I0", 1: "dd_1I1"},
61+
"past_key_values": [
62+
[{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}],
63+
[{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}],
64+
],
65+
"position_ids": {0: "dd_3I0", 1: "dd_3I1"},
66+
},
67+
),
68+
guessed,
69+
)
70+
4571

4672
if __name__ == "__main__":
4773
unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -630,9 +630,12 @@ def __init__(
630630
method_name: str = "forward",
631631
name: str = "main",
632632
):
633-
assert isinstance(model, torch.nn.Module) or inspect.ismodule(
634-
model
635-
), f"unexpected type for model={type(model)}, it must be a torch.nn.Module"
633+
assert (
634+
model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
635+
), (
636+
f"unexpected type for model={type(model)}, "
637+
f"it must be a torch.nn.Module or None"
638+
)
636639
assert name, (
637640
f"name={name!r} cannot be empty this string is used to "
638641
f"display meaningful error messages"
@@ -641,26 +644,42 @@ def __init__(
641644
self.model = model
642645
self.level = level
643646
self.method_name = method_name
644-
self.forward = getattr(model, method_name)
645-
self.signature = inspect.signature(self.forward)
647+
self.forward = getattr(model, method_name) if model is not None else None
648+
self.signature = inspect.signature(self.forward) if self.forward else None
646649

647650
# information about the signature
648-
self.forward_parameter_names = set(
649-
p.name
650-
for p in self.signature.parameters.values()
651-
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
651+
self.forward_parameter_names = (
652+
set(
653+
p.name
654+
for p in self.signature.parameters.values()
655+
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
656+
)
657+
if self.signature
658+
else None
659+
)
660+
self.forward_ordered_parameter_names = (
661+
list(self.signature.parameters) if self.signature else None
662+
)
663+
self.forward_positioned_parameter_names = (
664+
[
665+
p.name
666+
for p in self.signature.parameters.values()
667+
if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
668+
]
669+
if self.signature
670+
else None
671+
)
672+
names = (
673+
[p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
674+
if self.signature
675+
else None
652676
)
653-
self.forward_ordered_parameter_names = list(self.signature.parameters)
654-
self.forward_positioned_parameter_names = [
655-
p.name
656-
for p in self.signature.parameters.values()
657-
if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
658-
]
659-
names = [
660-
p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
661-
]
662677
self.forward_args = names[0] if names else None
663-
names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
678+
names = (
679+
[p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
680+
if self.signature
681+
else None
682+
)
664683
self.forward_kwargs = names[0] if names else None
665684
self.forward_custom_op_schema = None
666685
self.forward_need_serialization = False
@@ -942,7 +961,7 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
942961
)
943962
)
944963
names = s2.pop()
945-
for name in names:
964+
for i, name in enumerate(names):
946965
assert name not in {"_diag", "verbose"}, (
947966
f"{self.full_name}: unexpected parameter {name!r}, names={names}"
948967
f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"

onnx_diagnostic/export/shape_helper.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Any, Set
1+
from typing import Any, Dict, List, Set, Tuple, Union
22
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
3+
from .dynamic_shapes import ModelInputs
34

45

56
def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
@@ -47,3 +48,79 @@ def tensor_to_shape(tensor):
4748
return flatten_unflatten_for_dynamic_shapes(
4849
inputs, change_function=tensor_to_shape, use_dict=True
4950
)
51+
52+
53+
def guess_dynamic_shapes_from_inputs(
54+
inputs: List[Any], auto: Union[bool, str] = False
55+
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
56+
"""
57+
Guesses which dimension is dimension from a set of inputs.
58+
Every dimension having different values over multiple sets
59+
of inputs. Every dimension not changing remains static.
60+
61+
:param inputs: a list of input sets
62+
:param auto: True for ``torch.export.Dim.AUTO``,
63+
False for ``torch.export.Dim.DYNAMIC``,
64+
a string to get a unique string for every dynamic dimension
65+
:return: args and kwargs
66+
67+
.. runpython::
68+
:showcode:
69+
70+
import pprint
71+
import torch
72+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
73+
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
74+
75+
bsize, nheads, slen, dim = 2, 1, 30, 96
76+
inputs1 = dict(
77+
input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
78+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
79+
position_ids=torch.arange(3, dtype=torch.int64),
80+
past_key_values=make_dynamic_cache(
81+
[
82+
(
83+
torch.randn(bsize, nheads, slen, dim),
84+
torch.randn(bsize, nheads, slen, dim),
85+
),
86+
]
87+
),
88+
)
89+
bsize, nheads, slen, dim = 3, 1, 33, 96
90+
inputs2 = dict(
91+
input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
92+
attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
93+
position_ids=torch.arange(4, dtype=torch.int64),
94+
past_key_values=make_dynamic_cache(
95+
[
96+
(
97+
torch.randn(bsize, nheads, slen, dim),
98+
torch.randn(bsize, nheads, slen, dim),
99+
),
100+
]
101+
),
102+
)
103+
ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
104+
pprint.pprint(ds)
105+
106+
This function returns sometihng equivalent to function
107+
:class:`torch.export.dynamic_shapes.AdditionalInputs` but this
108+
one needs a model.
109+
110+
.. runpython::
111+
:showcode:
112+
113+
import pprint
114+
import torch
115+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
116+
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
117+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
118+
119+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
120+
ds = torch.export.dynamic_shapes.AdditionalInputs()
121+
ds.add((), data["inputs"])
122+
ds.add((), data["inputs2"])
123+
pprint.pprint(ds.dynamic_shapes(data["model"], (), inputs1))
124+
"""
125+
mi = ModelInputs(None, inputs)
126+
return mi.guess_dynamic_shapes(auto=auto)

0 commit comments

Comments
 (0)