diff --git a/_doc/conf.py b/_doc/conf.py index 72a4fd94..7539e6fc 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -233,6 +233,7 @@ def linkcode_resolve(domain, info): "onnx-script": "https://github.com/microsoft/onnxscript", "onnxscript": "https://github.com/microsoft/onnxscript", "onnxscript Tutorial": "https://microsoft.github.io/onnxscript/tutorial/index.html", + "optree": "https://github.com/metaopt/optree", "Pattern-based Rewrite Using Rules With onnxscript": "https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html", "opsets": "https://onnx.ai/onnx/intro/concepts.html#what-is-an-opset-version", "pyinstrument": "https://pyinstrument.readthedocs.io/en/latest/", diff --git a/_doc/recipes/plot_dynamic_shapes_json.py b/_doc/recipes/plot_dynamic_shapes_json.py new file mode 100644 index 00000000..e995d8ca --- /dev/null +++ b/_doc/recipes/plot_dynamic_shapes_json.py @@ -0,0 +1,113 @@ +""" +JSON returns list when the original dynamic shapes are list or tuple +==================================================================== + +Dynamic shapes given to :func:`torch.export.export` must follow the +same semantic. What if we confuse tuple and list when defining the dynamic shapes, +how to restore the expected type assuming we know the inputs? +Not often useful but maybe we will learn more about +:epkg:`optree`. + +Dynamic Shapes After JSON ++++++++++++++++++++++++++ + +JSON format does not make the difference between a list and a tuple. +So after serializing to json and restoring, both of them become lists. +""" + +import json +import pprint +import torch +from onnx_diagnostic import doc +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache +from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs + +bsize, nheads, slen, dim = 2, 1, 30, 96 + +inputs = dict( + input_mask_position=( + torch.randint(15, size=(2, 3), dtype=torch.int64), + torch.randint(1, size=(2, 33), dtype=torch.int64), + torch.arange(3, dtype=torch.int64), + ), + past_key_values=make_dynamic_cache( + [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))] + ), +) + +print(string_type(inputs, with_shape=True)) + +# %% +# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs` +# produces the corresponding dynamic shapes assuming they are all dynamic. +ds = all_dynamic_shape_from_inputs(inputs) +pprint.pprint(ds) + +# %% +# Converted into JSON. + +json_str = json.dumps(ds, indent=2, ensure_ascii=False) +print(json_str) + +# %% +# Restoration. +ds2 = json.loads(json_str) +pprint.pprint(ds2) + +# %% +# tuple are replaced by list. + +# The trick to restore tuple when expected +# ++++++++++++++++++++++++++++++++++++++++ + + +def flatten_unflatten_like_dynamic_shapes(obj): + if isinstance(obj, torch.Tensor): + return obj + flat, spec = torch.utils._pytree.tree_flatten(obj) + start = 0 + end = 0 + subtrees = [] + for subspec in spec.children_specs: + end += subspec.num_leaves + value = subspec.unflatten(flat[start:end]) + value = flatten_unflatten_like_dynamic_shapes(value) + subtrees.append(value) + start = end + if spec.type is dict or spec.context: + return dict(zip(spec.context, subtrees)) + if spec.type is tuple: + return tuple(subtrees) + return subtrees + + +def _align(inputs, ds): + if isinstance(inputs, torch.Tensor): + return ds + if isinstance(inputs, tuple): + return tuple(_align(o, d) for o, d in zip(inputs, ds)) + if isinstance(inputs, list): + return [_align(o, d) for o, d in zip(inputs, ds)] + if isinstance(inputs, dict): + return {k: _align(inputs[k], d) for k, d in ds.items()} + raise TypeError(f"Unexpected types inputs is {type(inputs)}, ds is {type(ds)}") + + +def fix_dynamic_shapes(inputs, dynamic_shapes): + flat_unflat_inputs = flatten_unflatten_like_dynamic_shapes(inputs) + return _align(flat_unflat_inputs, dynamic_shapes) + + +fixed_ds = fix_dynamic_shapes(inputs, ds2) +pprint.pprint(fixed_ds) + +# %% +# The code changed tuple into list as expected. +assert isinstance(ds2["input_mask_position"], list) +assert isinstance(fixed_ds["input_mask_position"], tuple) + + +# %% + +doc.plot_legend("dynamic shapes\nto json\nfrom json", "torch.export.export", "green") diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index 69c87399..30c9581c 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -1,7 +1,10 @@ import unittest import torch from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch -from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs +from onnx_diagnostic.export.shape_helper import ( + all_dynamic_shape_from_inputs, + guess_dynamic_shapes_from_inputs, +) from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs @@ -10,15 +13,17 @@ class TestShapeHelper(ExtTestCase): @requires_torch("2.7.99") def test_all_dynamic_shape_from_inputs(self): ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6)))) + self.assertEqual(({0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}), ds) + ds = all_dynamic_shape_from_inputs([torch.randn((5, 6)), torch.randn((1, 6))]) self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds) ds = all_dynamic_shape_from_inputs( (torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO ) self.assertEqual( - [ + ( {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO}, {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO}, - ], + ), ds, ) @@ -26,7 +31,6 @@ def test_all_dynamic_shape_from_inputs(self): @requires_torch("2.7.99") def test_all_dynamic_shape_from_inputs_dynamic_cache(self): data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") - print(self.string_type(data["inputs"], with_shape=True)) ds = all_dynamic_shape_from_inputs(data["inputs"]) self.assertEqual( { @@ -41,6 +45,29 @@ def test_all_dynamic_shape_from_inputs_dynamic_cache(self): ds, ) + @requires_transformers("4.52") + @requires_torch("2.7.99") + def test_guess_dynamic_shapes_from_inputs(self): + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True) + guessed = guess_dynamic_shapes_from_inputs( + [data["inputs"], data["inputs2"]], auto="dd" + ) + self.assertEqual( + ( + (), + { + "attention_mask": {0: "dd_0I0", 1: "dd_0I1"}, + "input_ids": {0: "dd_1I0", 1: "dd_1I1"}, + "past_key_values": [ + [{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}], + [{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}], + ], + "position_ids": {0: "dd_3I0", 1: "dd_3I1"}, + }, + ), + guessed, + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index 932a4329..c5d0695d 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -630,9 +630,12 @@ def __init__( method_name: str = "forward", name: str = "main", ): - assert isinstance(model, torch.nn.Module) or inspect.ismodule( - model - ), f"unexpected type for model={type(model)}, it must be a torch.nn.Module" + assert ( + model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model) + ), ( + f"unexpected type for model={type(model)}, " + f"it must be a torch.nn.Module or None" + ) assert name, ( f"name={name!r} cannot be empty this string is used to " f"display meaningful error messages" @@ -641,26 +644,42 @@ def __init__( self.model = model self.level = level self.method_name = method_name - self.forward = getattr(model, method_name) - self.signature = inspect.signature(self.forward) + self.forward = getattr(model, method_name) if model is not None else None + self.signature = inspect.signature(self.forward) if self.forward else None # information about the signature - self.forward_parameter_names = set( - p.name - for p in self.signature.parameters.values() - if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD} + self.forward_parameter_names = ( + set( + p.name + for p in self.signature.parameters.values() + if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD} + ) + if self.signature + else None + ) + self.forward_ordered_parameter_names = ( + list(self.signature.parameters) if self.signature else None + ) + self.forward_positioned_parameter_names = ( + [ + p.name + for p in self.signature.parameters.values() + if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + ] + if self.signature + else None + ) + names = ( + [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL] + if self.signature + else None ) - self.forward_ordered_parameter_names = list(self.signature.parameters) - self.forward_positioned_parameter_names = [ - p.name - for p in self.signature.parameters.values() - if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) - ] - names = [ - p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL - ] self.forward_args = names[0] if names else None - names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD] + names = ( + [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD] + if self.signature + else None + ) self.forward_kwargs = names[0] if names else None self.forward_custom_op_schema = None self.forward_need_serialization = False @@ -711,6 +730,7 @@ def process_inputs( @property def true_model_name(self) -> str: "Returns class name or module name." + assert self.model is not None, "model was None when the class was initialized." return ( self.model.__class__.__name__ if isinstance(self.model, torch.nn.Module) @@ -942,7 +962,7 @@ def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES ) ) names = s2.pop() - for name in names: + for i, name in enumerate(names): assert name not in {"_diag", "verbose"}, ( f"{self.full_name}: unexpected parameter {name!r}, names={names}" f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}" @@ -968,6 +988,14 @@ def move_to_kwargs( with the corresponding dynamic shapes. *kwargs*, *dynamic_shapes* are modified inplace. """ + assert ( + self.signature is not None + and self.forward_parameter_names is not None + and self.forward_ordered_parameter_names is not None + ), ( + "model was None when the class was initialized, " + "cannot move args to kwargs without the signature." + ) sig = self.signature arg_dyn, kw_dyn = dynamic_shapes for i, p in enumerate(sig.parameters): diff --git a/onnx_diagnostic/export/shape_helper.py b/onnx_diagnostic/export/shape_helper.py index 98c91025..489e65f0 100644 --- a/onnx_diagnostic/export/shape_helper.py +++ b/onnx_diagnostic/export/shape_helper.py @@ -1,5 +1,6 @@ -from typing import Any, Set +from typing import Any, Dict, List, Set, Tuple, Union from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes +from .dynamic_shapes import ModelInputs def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any: @@ -47,3 +48,79 @@ def tensor_to_shape(tensor): return flatten_unflatten_for_dynamic_shapes( inputs, change_function=tensor_to_shape, use_dict=True ) + + +def guess_dynamic_shapes_from_inputs( + inputs: List[Any], auto: Union[bool, str] = False +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """ + Guesses which dimension is dimension from a set of inputs. + Every dimension having different values over multiple sets + of inputs. Every dimension not changing remains static. + + :param inputs: a list of input sets + :param auto: True for ``torch.export.Dim.AUTO``, + False for ``torch.export.Dim.DYNAMIC``, + a string to get a unique string for every dynamic dimension + :return: args and kwargs + + .. runpython:: + :showcode: + + import pprint + import torch + from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache + from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs + + bsize, nheads, slen, dim = 2, 1, 30, 96 + inputs1 = dict( + input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64), + attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64), + position_ids=torch.arange(3, dtype=torch.int64), + past_key_values=make_dynamic_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ), + ] + ), + ) + bsize, nheads, slen, dim = 3, 1, 33, 96 + inputs2 = dict( + input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64), + attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64), + position_ids=torch.arange(4, dtype=torch.int64), + past_key_values=make_dynamic_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ), + ] + ), + ) + ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d") + pprint.pprint(ds) + + This function returns something equivalent to function + :class:`torch.export.dynamic_shapes.AdditionalInputs` but this + one needs a model. + + .. runpython:: + :showcode: + + import pprint + import torch + from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache + from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs + from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs + + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True) + ds = torch.export.dynamic_shapes.AdditionalInputs() + ds.add((), data["inputs"]) + ds.add((), data["inputs2"]) + pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"])) + """ + mi = ModelInputs(None, inputs) + return mi.guess_dynamic_shapes(auto=auto) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 9b1ba31c..ae4556fd 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -19,7 +19,7 @@ def flatten_unflatten_for_dynamic_shapes( :func:`torch.export.export` only considers the values, the context gives the dictionary keys but it is not expressed in the dynamic shapes, these specifications seems to be different - for the strict and non strict mode. + for the strict and non strict mode. It also preserves tuple. :param change_function: to modifies the tensor in the structure itself, like replace them by a shape :return: the serialized object @@ -38,9 +38,12 @@ def flatten_unflatten_for_dynamic_shapes( ) subtrees.append(value) start = end - if use_dict and (spec.type is dict or spec.context): - # This a dictionary. - return dict(zip(spec.context, subtrees)) + if use_dict: + if spec.type is dict or spec.context: + # This a dictionary. + return dict(zip(spec.context, subtrees)) + if spec.type is tuple: + return tuple(subtrees) # This is a list. return subtrees