Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def linkcode_resolve(domain, info):
"onnx-script": "https://github.com/microsoft/onnxscript",
"onnxscript": "https://github.com/microsoft/onnxscript",
"onnxscript Tutorial": "https://microsoft.github.io/onnxscript/tutorial/index.html",
"optree": "https://github.com/metaopt/optree",
"Pattern-based Rewrite Using Rules With onnxscript": "https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html",
"opsets": "https://onnx.ai/onnx/intro/concepts.html#what-is-an-opset-version",
"pyinstrument": "https://pyinstrument.readthedocs.io/en/latest/",
Expand Down
113 changes: 113 additions & 0 deletions _doc/recipes/plot_dynamic_shapes_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
JSON returns list when the original dynamic shapes are list or tuple
====================================================================

Dynamic shapes given to :func:`torch.export.export` must follow the
same semantic. What if we confuse tuple and list when defining the dynamic shapes,
how to restore the expected type assuming we know the inputs?
Not often useful but maybe we will learn more about
:epkg:`optree`.

Dynamic Shapes After JSON
+++++++++++++++++++++++++

JSON format does not make the difference between a list and a tuple.
So after serializing to json and restoring, both of them become lists.
"""

import json
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

bsize, nheads, slen, dim = 2, 1, 30, 96

inputs = dict(
input_mask_position=(
torch.randint(15, size=(2, 3), dtype=torch.int64),
torch.randint(1, size=(2, 33), dtype=torch.int64),
torch.arange(3, dtype=torch.int64),
),
past_key_values=make_dynamic_cache(
[(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
),
)

print(string_type(inputs, with_shape=True))

# %%
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
# produces the corresponding dynamic shapes assuming they are all dynamic.
ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)

# %%
# Converted into JSON.

json_str = json.dumps(ds, indent=2, ensure_ascii=False)
print(json_str)

# %%
# Restoration.
ds2 = json.loads(json_str)
pprint.pprint(ds2)

# %%
# tuple are replaced by list.

# The trick to restore tuple when expected
# ++++++++++++++++++++++++++++++++++++++++


def flatten_unflatten_like_dynamic_shapes(obj):
if isinstance(obj, torch.Tensor):
return obj
flat, spec = torch.utils._pytree.tree_flatten(obj)
start = 0
end = 0
subtrees = []
for subspec in spec.children_specs:
end += subspec.num_leaves
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_like_dynamic_shapes(value)
subtrees.append(value)
start = end
if spec.type is dict or spec.context:
return dict(zip(spec.context, subtrees))
if spec.type is tuple:
return tuple(subtrees)
return subtrees


def _align(inputs, ds):
if isinstance(inputs, torch.Tensor):
return ds
if isinstance(inputs, tuple):
return tuple(_align(o, d) for o, d in zip(inputs, ds))
if isinstance(inputs, list):
return [_align(o, d) for o, d in zip(inputs, ds)]
if isinstance(inputs, dict):
return {k: _align(inputs[k], d) for k, d in ds.items()}
raise TypeError(f"Unexpected types inputs is {type(inputs)}, ds is {type(ds)}")


def fix_dynamic_shapes(inputs, dynamic_shapes):
flat_unflat_inputs = flatten_unflatten_like_dynamic_shapes(inputs)
return _align(flat_unflat_inputs, dynamic_shapes)


fixed_ds = fix_dynamic_shapes(inputs, ds2)
pprint.pprint(fixed_ds)

# %%
# The code changed tuple into list as expected.
assert isinstance(ds2["input_mask_position"], list)
assert isinstance(fixed_ds["input_mask_position"], tuple)


# %%

doc.plot_legend("dynamic shapes\nto json\nfrom json", "torch.export.export", "green")
35 changes: 31 additions & 4 deletions _unittests/ut_export/test_shape_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.export.shape_helper import (
all_dynamic_shape_from_inputs,
guess_dynamic_shapes_from_inputs,
)
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs


Expand All @@ -10,23 +13,24 @@ class TestShapeHelper(ExtTestCase):
@requires_torch("2.7.99")
def test_all_dynamic_shape_from_inputs(self):
ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
self.assertEqual(({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}), ds)
ds = all_dynamic_shape_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))])
self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
ds = all_dynamic_shape_from_inputs(
(torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
)
self.assertEqual(
[
(
{0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
{0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
],
),
ds,
)

@requires_transformers("4.52")
@requires_torch("2.7.99")
def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
print(self.string_type(data["inputs"], with_shape=True))
ds = all_dynamic_shape_from_inputs(data["inputs"])
self.assertEqual(
{
Expand All @@ -41,6 +45,29 @@ def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
ds,
)

@requires_transformers("4.52")
@requires_torch("2.7.99")
def test_guess_dynamic_shapes_from_inputs(self):
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
guessed = guess_dynamic_shapes_from_inputs(
[data["inputs"], data["inputs2"]], auto="dd"
)
self.assertEqual(
(
(),
{
"attention_mask": {0: "dd_0I0", 1: "dd_0I1"},
"input_ids": {0: "dd_1I0", 1: "dd_1I1"},
"past_key_values": [
[{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}],
[{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}],
],
"position_ids": {0: "dd_3I0", 1: "dd_3I1"},
},
),
guessed,
)


if __name__ == "__main__":
unittest.main(verbosity=2)
68 changes: 48 additions & 20 deletions onnx_diagnostic/export/dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,9 +630,12 @@ def __init__(
method_name: str = "forward",
name: str = "main",
):
assert isinstance(model, torch.nn.Module) or inspect.ismodule(
model
), f"unexpected type for model={type(model)}, it must be a torch.nn.Module"
assert (
model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
), (
f"unexpected type for model={type(model)}, "
f"it must be a torch.nn.Module or None"
)
assert name, (
f"name={name!r} cannot be empty this string is used to "
f"display meaningful error messages"
Expand All @@ -641,26 +644,42 @@ def __init__(
self.model = model
self.level = level
self.method_name = method_name
self.forward = getattr(model, method_name)
self.signature = inspect.signature(self.forward)
self.forward = getattr(model, method_name) if model is not None else None
self.signature = inspect.signature(self.forward) if self.forward else None

# information about the signature
self.forward_parameter_names = set(
p.name
for p in self.signature.parameters.values()
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
self.forward_parameter_names = (
set(
p.name
for p in self.signature.parameters.values()
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
)
if self.signature
else None
)
self.forward_ordered_parameter_names = (
list(self.signature.parameters) if self.signature else None
)
self.forward_positioned_parameter_names = (
[
p.name
for p in self.signature.parameters.values()
if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
]
if self.signature
else None
)
names = (
[p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
if self.signature
else None
)
self.forward_ordered_parameter_names = list(self.signature.parameters)
self.forward_positioned_parameter_names = [
p.name
for p in self.signature.parameters.values()
if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
]
names = [
p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
]
self.forward_args = names[0] if names else None
names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
names = (
[p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
if self.signature
else None
)
self.forward_kwargs = names[0] if names else None
self.forward_custom_op_schema = None
self.forward_need_serialization = False
Expand Down Expand Up @@ -711,6 +730,7 @@ def process_inputs(
@property
def true_model_name(self) -> str:
"Returns class name or module name."
assert self.model is not None, "model was None when the class was initialized."
return (
self.model.__class__.__name__
if isinstance(self.model, torch.nn.Module)
Expand Down Expand Up @@ -942,7 +962,7 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES
)
)
names = s2.pop()
for name in names:
for i, name in enumerate(names):
assert name not in {"_diag", "verbose"}, (
f"{self.full_name}: unexpected parameter {name!r}, names={names}"
f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
Expand All @@ -968,6 +988,14 @@ def move_to_kwargs(
with the corresponding dynamic shapes.
*kwargs*, *dynamic_shapes* are modified inplace.
"""
assert (
self.signature is not None
and self.forward_parameter_names is not None
and self.forward_ordered_parameter_names is not None
), (
"model was None when the class was initialized, "
"cannot move args to kwargs without the signature."
)
sig = self.signature
arg_dyn, kw_dyn = dynamic_shapes
for i, p in enumerate(sig.parameters):
Expand Down
79 changes: 78 additions & 1 deletion onnx_diagnostic/export/shape_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Set
from typing import Any, Dict, List, Set, Tuple, Union
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from .dynamic_shapes import ModelInputs


def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
Expand Down Expand Up @@ -47,3 +48,79 @@ def tensor_to_shape(tensor):
return flatten_unflatten_for_dynamic_shapes(
inputs, change_function=tensor_to_shape, use_dict=True
)


def guess_dynamic_shapes_from_inputs(
inputs: List[Any], auto: Union[bool, str] = False
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""
Guesses which dimension is dimension from a set of inputs.
Every dimension having different values over multiple sets
of inputs. Every dimension not changing remains static.

:param inputs: a list of input sets
:param auto: True for ``torch.export.Dim.AUTO``,
False for ``torch.export.Dim.DYNAMIC``,
a string to get a unique string for every dynamic dimension
:return: args and kwargs

.. runpython::
:showcode:

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

bsize, nheads, slen, dim = 2, 1, 30, 96
inputs1 = dict(
input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
position_ids=torch.arange(3, dtype=torch.int64),
past_key_values=make_dynamic_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
),
]
),
)
bsize, nheads, slen, dim = 3, 1, 33, 96
inputs2 = dict(
input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
position_ids=torch.arange(4, dtype=torch.int64),
past_key_values=make_dynamic_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
),
]
),
)
ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
pprint.pprint(ds)

This function returns something equivalent to function
:class:`torch.export.dynamic_shapes.AdditionalInputs` but this
one needs a model.

.. runpython::
:showcode:

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
ds = torch.export.dynamic_shapes.AdditionalInputs()
ds.add((), data["inputs"])
ds.add((), data["inputs2"])
pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
"""
mi = ModelInputs(None, inputs)
return mi.guess_dynamic_shapes(auto=auto)
Loading
Loading