diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 3dd067aa..a5286012 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,10 @@ Change Logs 0.4.0 +++++ +* :pr:`58`: add function ``use_dyn_not_str`` to replace strings by ``torch.export.Dim.DYNAMIC``, + and use strings instead of ``torch.export.Dim.DYNAMIC`` when returning the dynamic shapes + of a specific model; strings are a valid definition for ``torch.onnx.export``, + which can reuse the names * :pr:`55`: add support for text-classification * :pr:`54`: add support for fill-mask, refactoring * :pr:`52`: add support for zero-shot-image-classification diff --git a/_doc/examples/plot_export_hub_codellama.py b/_doc/examples/plot_export_hub_codellama.py index 628f2a09..533ad0dd 100644 --- a/_doc/examples/plot_export_hub_codellama.py +++ b/_doc/examples/plot_export_hub_codellama.py @@ -30,6 +30,7 @@ task_from_id, ) from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str model_id = "codellama/CodeLlama-7b-Python-hf" print("info", get_model_info(model_id)) @@ -96,7 +97,7 @@ model, (), kwargs=f(data["inputs"]), - dynamic_shapes=data["dynamic_shapes"], + dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), strict=False, ) print(ep) diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py index 7022c981..ddcf1480 100644 --- a/_doc/examples/plot_export_tiny_phi2.py +++ b/_doc/examples/plot_export_tiny_phi2.py @@ -27,6 +27,7 @@ from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered from onnx_diagnostic.helpers.rt_helper import make_feeds from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str from onnx_diagnostic.torch_models.hghub import ( get_untrained_model_with_inputs, ) @@ -92,7 +93,7 @@ untrained_model, (), kwargs=modificator(copy.deepcopy(inputs)), - dynamic_shapes=dynamic_shapes, + dynamic_shapes=use_dyn_not_str(dynamic_shapes), strict=False, # mandatory for torch==2.6 ) diff --git a/_doc/index.rst b/_doc/index.rst index 85b3b282..b449d7e5 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -20,6 +20,7 @@ onnx-diagnostic: investigate onnx models The main feature is about `patches `_: it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches. +Sources are available at `github/onnx-diagnostic `_. ..
code-block:: python diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 08da1426..be8fb56c 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -576,17 +576,20 @@ def test_couple_input_ds_cache(self): Cls( (), kwargs, - {"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])}, + { + "A": ds_batch, + "B": (ds_batch, [[ds_batch, ds_batch], [ds_batch, ds_batch]]), + }, ).invalid_dimensions_for_export(), ) self.assertEqual( - {"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])}, + {"B": (None, [[None, {2: "d=[1]"}], [None, {2: "d=[1]"}]])}, Cls( (), kwargs, { "A": ds_batch, - "B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]), + "B": (ds_batch, [[ds_batch, ds_batch_seq], [ds_batch, ds_batch_seq]]), }, ).invalid_dimensions_for_export(), ) @@ -762,10 +765,8 @@ def test_dynamic_cache_replace_by_string(self): self.assertEqual( { "cache": [ - {0: "Dim0", 1: "Dim1"}, - {0: "Dim2", 1: "Dim3"}, - {0: "Dim4", 1: "Dim5"}, - {0: "Dim6", 1: "Dim7"}, + [{0: "Dim0", 1: "Dim1"}, {0: "Dim2", 1: "Dim3"}], + [{0: "Dim4", 1: "Dim5"}, {0: "Dim6", 1: "Dim7"}], ] }, as_string, diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py new file mode 100644 index 00000000..4ff13042 --- /dev/null +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -0,0 +1,71 @@ +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache +from onnx_diagnostic.export import CoupleInputsDynamicShapes +from onnx_diagnostic.torch_export_patches.patch_inputs import ( + convert_dynamic_axes_into_dynamic_shapes, +) +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors + + +class TestCacheHelpers(ExtTestCase): + def test_string_type(self): + DYN = torch.export.Dim.DYNAMIC + self.assertEqual("DYNAMIC", string_type(DYN, verbose=0)) + AUTO = torch.export.Dim.AUTO + self.assertEqual("AUTO", string_type(AUTO, verbose=0)) + self.assertEqual("#1[DYNAMIC]", string_type([DYN])) + + batch = torch.export.Dim("batch") + dynamic_shapes = dict( + input_ids={0: batch, 1: "seq"}, + attention_mask={0: batch, 1: "seq"}, + position_ids={0: batch, 1: "seq"}, + past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]], + ) + self.assertEqual( + "dict(input_ids:{0:Dim(batch),1:DYN(seq)}," + "attention_mask:{0:Dim(batch),1:DYN(seq)}," + "position_ids:{0:Dim(batch),1:DYN(seq)}," + "past_key_values:#2[#1[{0:Dim(batch),2:DYN(seq)}]," + "#1[{0:Dim(batch),2:DYN(seq)}]])", + string_type(dynamic_shapes), + ) + + def test_replace_by(self): + bsize, nheads, slen, dim = 2, 4, 3, 7 + + past_key_values = make_dynamic_cache( + [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))] + ) + kwargs = dict( + input_ids=torch.zeros(2, 3), + attention_mask=torch.zeros(2, 3), + position_ids=torch.zeros(2, 3), + past_key_values=past_key_values, + ) + batch = torch.export.Dim("batch") + dynamic_shapes = dict( + input_ids={0: batch, 1: "seq"}, + attention_mask={0: batch, 1: "seq"}, + position_ids={0: batch, 1: "seq"}, + past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]], + ) + + DYN = torch.export.Dim.DYNAMIC + nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes( + None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes + ) + self.assertEqual(dynamic_shapes, nds) + + with 
bypass_export_some_errors(patch_transformers=True): + cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes) + res = cpl.replace_string_by() + dsc = res["past_key_values"] + self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_inputs.py b/_unittests/ut_torch_export_patches/test_patch_inputs.py index ab7e9575..e148efde 100644 --- a/_unittests/ut_torch_export_patches/test_patch_inputs.py +++ b/_unittests/ut_torch_export_patches/test_patch_inputs.py @@ -5,6 +5,7 @@ from onnx_diagnostic.helpers import string_type from onnx_diagnostic.torch_export_patches.patch_inputs import ( convert_dynamic_axes_into_dynamic_shapes, + use_dyn_not_str, ) @@ -111,6 +112,26 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self): string_type(res[1], with_shape=True), ) + def test_use_dyn_not_str(self): + batch = torch.export.Dim("batch") + dynamic_shapes = dict( + input_ids={0: batch, 1: "seq"}, + attention_mask={0: batch, 1: "seq"}, + position_ids={0: batch, 1: "seq"}, + past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]], + ) + res = use_dyn_not_str(dynamic_shapes) + DYN = torch.export.Dim.DYNAMIC + self.assertEqual( + dict( + input_ids={0: batch, 1: DYN}, + attention_mask={0: batch, 1: DYN}, + position_ids={0: batch, 1: DYN}, + past_key_values=[[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], + ), + res, + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_models/test_test_helpers.py b/_unittests/ut_torch_models/test_test_helpers.py index a1c99a89..d2a32b84 100644 --- a/_unittests/ut_torch_models/test_test_helpers.py +++ b/_unittests/ut_torch_models/test_test_helpers.py @@ -2,7 +2,12 @@ import unittest import packaging.version as pv import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + hide_stdout, + ignore_warnings, + requires_torch, +) from onnx_diagnostic.torch_models.test_helper import ( get_inputs_for_task, validate_model, @@ -54,6 +59,7 @@ def test_validate_model_export(self): self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) + @requires_torch("2.7") @hide_stdout() @ignore_warnings(FutureWarning) def test_validate_model_onnx(self): diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index b53e196b..d1edec06 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -302,7 +302,7 @@ def get_parser_validate() -> ArgumentParser: def _cmd_validate(argv: List[Any]): from .helpers import string_type - from .torch_models.test_helper import get_inputs_for_task, validate_model, _ds_clean + from .torch_models.test_helper import get_inputs_for_task, validate_model from .tasks import supported_tasks parser = get_parser_validate() @@ -320,7 +320,7 @@ def _cmd_validate(argv: List[Any]): print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}") print("-- dynamic_shapes") for k, v in data["dynamic_shapes"].items(): - print(f" + {k.ljust(max_length)}: {_ds_clean(v)}") + print(f" + {k.ljust(max_length)}: {string_type(v)}") else: # Let's skip any invalid combination if known to be unsupported if ( diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index 987346fc..944c7bc0 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ 
b/onnx_diagnostic/export/dynamic_shapes.py @@ -92,7 +92,8 @@ def replace_string_by(self, value: Any = None): return self._generic_walker( lambda inputs, ds, value=value: self._replace_string_dim_tensor( inputs, ds, value=value - ) + ), + flatten_unflatten=True, ) @classmethod @@ -135,7 +136,8 @@ def replace_by_string(self): return self._generic_walker( lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string( inputs, ds, unique=unique - ) + ), + flatten_unflatten=True, ) @classmethod @@ -203,7 +205,7 @@ def invalid_dimensions_for_export(self): ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export()) """ - return self._generic_walker(self._valid_shapes_tensor) + return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True) @classmethod def _valid_shapes_tensor(cls, inputs, ds): @@ -221,7 +223,9 @@ def _valid_shapes_tensor(cls, inputs, ds): issues[i] = f"d=[{d}]" return issues if issues else None - def _generic_walker(self, processor: Callable, args_kwargs: bool = False): + def _generic_walker( + self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False + ): """ Generic deserializator walking through inputs and dynamic_shapes all along. The function returns a result with the same structure as the dynamic shapes. @@ -231,7 +235,12 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False): f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - res = self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) + res = self._generic_walker_step( + processor, + self.kwargs, + self.dynamic_shapes, + flatten_unflatten=flatten_unflatten, + ) return (tuple(), res) if args_kwargs else res if not self.kwargs: @@ -239,7 +248,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False): f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - res = self._generic_walker_step(processor, self.args, self.dynamic_shapes) + res = self._generic_walker_step( + processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten + ) return (res, {}) if args_kwargs else res assert isinstance(self.dynamic_shapes, dict), ( @@ -250,12 +261,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False): self.dynamic_shapes ): # No dynamic shapes for the positional arguments. - return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) + return self._generic_walker_step( + processor, + self.kwargs, + self.dynamic_shapes, + flatten_unflatten=flatten_unflatten, + ) if isinstance(self.args_names, list): if not set(self.args_names) & set(self.dynamic_shapes): # No dynamic shapes for the positional arguments. 
- return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) + return self._generic_walker_step( + processor, + self.kwargs, + self.dynamic_shapes, + flatten_unflatten=flatten_unflatten, + ) assert self.args_names, ( "args and kwargs are filled, then args_names must be specified in " @@ -268,7 +289,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False): ) kwargs = dict(zip(self.args_names, self.args)) kwargs.update(self.kwargs) - res = self._generic_walker_step(processor, kwargs, self.dynamic_shapes) + res = self._generic_walker_step( + processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten + ) if args_kwargs: pgs = [None for _ in range(len(self.args))] kws = {} @@ -286,7 +309,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False): ) @classmethod - def _generic_walker_step(cls, processor: Callable, inputs, ds): + def _generic_walker_step( + cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False + ): if isinstance(inputs, torch.Tensor): return processor(inputs, ds) if isinstance(inputs, (int, float, str)): @@ -303,7 +328,11 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds): if isinstance(inputs, (tuple, list)): value = [] for i, d in zip(inputs, ds): - value.append(cls._generic_walker_step(processor, i, d)) + value.append( + cls._generic_walker_step( + processor, i, d, flatten_unflatten=flatten_unflatten + ) + ) return ( (value if isinstance(ds, list) else tuple(value)) if any(v is not None for v in value) @@ -314,7 +343,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds): ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}" dvalue = {} for k, v in inputs.items(): - t = cls._generic_walker_step(processor, v, ds[k]) + t = cls._generic_walker_step( + processor, v, ds[k], flatten_unflatten=flatten_unflatten + ) if t is not None: dvalue[k] = t return dvalue if dvalue else None @@ -325,11 +356,18 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds): f"torch.utils._pytree.register_pytree_node, it is not possible to " f"map this class with the given dynamic shapes." ) + if flatten_unflatten: + flatunflat = flatten_unflatten_for_dynamic_shapes(inputs) + return cls._generic_walker_step( + processor, flatunflat, ds, flatten_unflatten=flatten_unflatten + ) flat, _spec = torch.utils._pytree.tree_flatten(inputs) if all(isinstance(t, torch.Tensor) for t in flat): # We need to flatten dynamic shapes as well ds = flatten_dynamic_shapes(ds) - return cls._generic_walker_step(processor, flat, ds) + return cls._generic_walker_step( + processor, flat, ds, flatten_unflatten=flatten_unflatten + ) class ChangeDimensionProcessor: def __init__(self, desired_values): diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 5c3a511c..db399801 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -98,7 +98,8 @@ def string_type( with_min_max: bool = False, with_device: bool = False, ignore: bool = False, - limit: int = 10, + limit: int = 20, + verbose: int = 0, ) -> str: """ Displays the types of an object as a string. @@ -108,6 +109,7 @@ def string_type( :param with_min_max: displays information about the values :param with_device: display the device :param ignore: if True, just prints the type for unknown types + :param verbose: verbosity (to show the path it followed to get that print) :return: str .. 
runpython:: @@ -140,19 +142,9 @@ def string_type( print(string_type(inputs, with_shape=True, with_min_max=True)) """ if obj is None: + if verbose: + print(f"[string_type] A:{type(obj)}") return "None" - if is_dataclass(obj): - values = {f.name: getattr(obj, f.name, None) for f in fields(obj)} - values = {k: v for k, v in values.items() if v is not None} - s = string_type( - values, - with_shape=with_shape, - with_min_max=with_min_max, - with_device=with_device, - ignore=ignore, - limit=limit, - ) - return f"{obj.__class__.__name__}{s[4:]}" # tuple if isinstance(obj, tuple): @@ -164,7 +156,10 @@ def string_type( with_device=with_device, ignore=ignore, limit=limit, + verbose=verbose, ) + if verbose: + print(f"[string_type] C:{type(obj)}") return f"({s},)" if len(obj) < limit: js = ",".join( @@ -175,9 +170,12 @@ def string_type( with_device=with_device, ignore=ignore, limit=limit, + verbose=verbose, ) for o in obj ) + if verbose: + print(f"[string_type] D:{type(obj)}") return f"({js})" tt = string_type( obj[0], @@ -186,10 +184,15 @@ def string_type( with_device=with_device, ignore=ignore, limit=limit, + verbose=verbose, ) if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj): mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj) + if verbose: + print(f"[string_type] E:{type(obj)}") return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]" + if verbose: + print(f"[string_type] F:{type(obj)}") return f"#{len(obj)}({tt},...)" # list if isinstance(obj, list): @@ -202,9 +205,12 @@ def string_type( with_device=with_device, ignore=ignore, limit=limit, + verbose=verbose, ) for o in obj ) + if verbose: + print(f"[string_type] G:{type(obj)}") return f"#{len(obj)}[{js}]" tt = string_type( obj[0], @@ -213,10 +219,15 @@ def string_type( with_device=with_device, ignore=ignore, limit=limit, + verbose=verbose, ) if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj): mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj) + if verbose: + print(f"[string_type] H:{type(obj)}") return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]" + if verbose: + print(f"[string_type] I:{type(obj)}") return f"#{len(obj)}[{tt},...]" # set if isinstance(obj, set): @@ -229,30 +240,70 @@ def string_type( with_device=with_device, ignore=ignore, limit=limit, + verbose=verbose, ) for o in obj ) + if verbose: + print(f"[string_type] J:{type(obj)}") return f"{{{js}}}" if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj): mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj) + if verbose: + print(f"[string_type] K:{type(obj)}") return f"{{...}}#{len(obj)}[{mini},{maxi}:A{avg}]" + if verbose: + print(f"[string_type] L:{type(obj)}") return f"{{...}}#{len(obj)}" if with_shape else "{...}" # dict - if isinstance(obj, dict): + if isinstance(obj, dict) and type(obj) is dict: if len(obj) == 0: + if verbose: + print(f"[string_type] M:{type(obj)}") return "{}" + + import torch + + if all(isinstance(k, int) for k in obj) and all( + isinstance( + v, + ( + str, + torch.export.dynamic_shapes._Dim, + torch.export.dynamic_shapes._DerivedDim, + torch.export.dynamic_shapes._DimHint, + ), + ) + for v in obj.values() + ): + # This is dynamic shapes + rows = [] + for k, v in obj.items(): + if isinstance(v, str): + rows.append(f"{k}:DYN({v})") + else: + rows.append(f"{k}:{string_type(v, verbose=verbose)}") + if verbose: + print(f"[string_type] DS0:{type(obj)}") + return f"{{{','.join(rows)}}}" + kws = dict( with_shape=with_shape, 
with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, + verbose=verbose, ) s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items()) if all(isinstance(k, int) for k in obj): + if verbose: + print(f"[string_type] N:{type(obj)}") return f"{{{s}}}" + if verbose: + print(f"[string_type] O:{type(obj)}") return f"dict({s})" - # arrat + # array if isinstance(obj, np.ndarray): from .onnx_helper import np_dtype_to_tensor_dtype @@ -267,25 +318,113 @@ def string_type( nob = obj.ravel() nob = nob[~np.isnan(nob)] if nob.size == 0: + if verbose: + print(f"[string_type] A1:{type(obj)}") return f"{s}[N{n_nan}nans]" + if verbose: + print(f"[string_type] A2:{type(obj)}") return f"{s}[{nob.min()},{nob.max()}:A{nob.astype(float).mean()}N{n_nan}nans]" + if verbose: + print(f"[string_type] A3:{type(obj)}") return f"{s}[{obj.min()},{obj.max()}:A{obj.astype(float).mean()}]" i = np_dtype_to_tensor_dtype(obj.dtype) if not with_shape: + if verbose: + print(f"[string_type] A4:{type(obj)}") return f"A{i}r{len(obj.shape)}" + if verbose: + print(f"[string_type] A5:{type(obj)}") return f"A{i}s{'x'.join(map(str, obj.shape))}" import torch # Dim, SymInt if isinstance(obj, torch.export.dynamic_shapes._DerivedDim): + if verbose: + print(f"[string_type] Y1:{type(obj)}") return "DerivedDim" if isinstance(obj, torch.export.dynamic_shapes._Dim): + if verbose: + print(f"[string_type] Y2:{type(obj)}") return f"Dim({obj.__name__})" if isinstance(obj, torch.SymInt): + if verbose: + print(f"[string_type] Y3:{type(obj)}") return "SymInt" if isinstance(obj, torch.SymFloat): + if verbose: + print(f"[string_type] Y4:{type(obj)}") return "SymFloat" + + if isinstance(obj, torch.export.dynamic_shapes._DimHint): + cl = ( + torch.export.dynamic_shapes._DimHintType + if hasattr(torch.export.dynamic_shapes, "_DimHintType") + else torch.export.Dim + ) + if obj in (torch.export.Dim.DYNAMIC, cl.DYNAMIC): + if verbose: + print(f"[string_type] Y8:{type(obj)}") + return "DYNAMIC" + if obj in (torch.export.Dim.AUTO, cl.AUTO): + if verbose: + print(f"[string_type] Y9:{type(obj)}") + return "AUTO" + if verbose: + print(f"[string_type] Y7:{type(obj)}") + return str(obj) + + if isinstance(obj, bool): + if with_min_max: + if verbose: + print(f"[string_type] W1:{type(obj)}") + return f"bool={obj}" + if verbose: + print(f"[string_type] W2:{type(obj)}") + return "bool" + if isinstance(obj, int): + if with_min_max: + if verbose: + print(f"[string_type] W3:{type(obj)}") + return f"int={obj}" + if verbose: + print(f"[string_type] W4:{type(obj)}") + return "int" + if isinstance(obj, float): + if with_min_max: + if verbose: + print(f"[string_type] W6:{type(obj)}") + return f"float={obj}" + if verbose: + print(f"[string_type] W8:{type(obj)}") + return "float" + if isinstance(obj, str): + if verbose: + print(f"[string_type] W9:{type(obj)}") + return "str" + if isinstance(obj, slice): + if verbose: + print(f"[string_type] W10:{type(obj)}") + return "slice" + + if is_dataclass(obj): + # That includes torch.export.Dim.AUTO, torch.export.Dim.DYNAMIC so they need to be + # handled before that. 
+ values = {f.name: getattr(obj, f.name, None) for f in fields(obj)} + values = {k: v for k, v in values.items() if v is not None} + s = string_type( + values, + with_shape=with_shape, + with_min_max=with_min_max, + with_device=with_device, + ignore=ignore, + limit=limit, + verbose=verbose, + ) + if verbose: + print(f"[string_type] B:{type(obj)}") + return f"{obj.__class__.__name__}{s[4:]}" + # Tensors if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor): from .onnx_helper import torch_dtype_to_onnx_dtype @@ -293,7 +432,11 @@ def string_type( i = torch_dtype_to_onnx_dtype(obj.dtype) prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else "" if not with_shape: + if verbose: + print(f"[string_type] F1:{type(obj)}") return f"{prefix}F{i}r{len(obj.shape)}" + if verbose: + print(f"[string_type] F2:{type(obj)}") return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}" if isinstance(obj, torch.Tensor): from .onnx_helper import torch_dtype_to_onnx_dtype @@ -301,67 +444,73 @@ def string_type( if with_min_max: s = string_type(obj, with_shape=with_shape, with_device=with_device) if len(obj.shape) == 0: + if verbose: + print(f"[string_type] T1:{type(obj)}") return f"{s}={obj}" if obj.numel() == 0: + if verbose: + print(f"[string_type] T2:{type(obj)}") return f"{s}[empty]" n_nan = obj.reshape((-1,)).isnan().to(int).sum() if n_nan > 0: nob = obj.reshape((-1,)) nob = nob[~nob.isnan()] if obj.dtype in {torch.complex64, torch.complex128}: + if verbose: + print(f"[string_type] T3:{type(obj)}") return ( f"{s}[{nob.abs().min()},{nob.abs().max():A{nob.mean()}N{n_nan}nans}]" ) + if verbose: + print(f"[string_type] T5:{type(obj)}") return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}N{n_nan}nans]" if obj.dtype in {torch.complex64, torch.complex128}: + if verbose: + print(f"[string_type] T6:{type(obj)}") return f"{s}[{obj.abs().min()},{obj.abs().max()}:A{obj.abs().mean()}]" + if verbose: + print(f"[string_type] T7:{type(obj)}") return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}]" i = torch_dtype_to_onnx_dtype(obj.dtype) prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else "" if not with_shape: + if verbose: + print(f"[string_type] T8:{type(obj)}") return f"{prefix}T{i}r{len(obj.shape)}" + if verbose: + print(f"[string_type] T9:{type(obj)}") return f"{prefix}T{i}s{'x'.join(map(str, obj.shape))}" if obj.__class__.__name__ == "OrtValue": if not obj.has_value(): + if verbose: + print(f"[string_type] V1:{type(obj)}") return "OV()" if not obj.is_tensor(): + if verbose: + print(f"[string_type] V2:{type(obj)}") return "OV(NOTENSOR)" if with_min_max: try: t = obj.numpy() except Exception: # pass unable to convert into numpy (bfloat16, ...) 
+ if verbose: + print(f"[string_type] V3:{type(obj)}") return "OV(NO-NUMPY:FIXIT)" + if verbose: + print(f"[string_type] V4:{type(obj)}") return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})" dt = obj.element_type() shape = obj.shape() if with_shape: + if verbose: + print(f"[string_type] V5:{type(obj)}") return f"OV{dt}s{'x'.join(map(str, shape))}" + if verbose: + print(f"[string_type] V6:{type(obj)}") return f"OV{dt}r{len(shape)}" - if isinstance(obj, bool): - if with_min_max: - return f"bool={obj}" - return "bool" - if isinstance(obj, int): - if with_min_max: - return f"int={obj}" - return "int" - if isinstance(obj, float): - if with_min_max: - return f"float={obj}" - return "float" - if isinstance(obj, str): - return "str" - if isinstance(obj, slice): - return "slice" - - if obj == torch.export.Dim.DYNAMIC: - return "DYNAMIC" - if obj == torch.export.Dim.AUTO: - return "AUTO" - # others classes if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES: @@ -374,13 +523,20 @@ def string_type( with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) + if verbose: + print(f"[string_type] DS:{type(obj)}") return f"{obj.__class__.__name__}[serialized]({att})" if type(obj).__name__ == "Node" and hasattr(obj, "meta"): # torch.fx.node.Node + if verbose: + print(f"[string_type] TT1:{type(obj)}") return f"%{obj.target}" if type(obj).__name__ == "ValueInfoProto": + if verbose: + print(f"[string_type] OO1:{type(obj)}") return f"OT{obj.type.tensor_type.elem_type}" if obj.__class__.__name__ == "BatchFeature": @@ -390,7 +546,10 @@ def string_type( with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) + if verbose: + print(f"[string_type] TT2:{type(obj)}") return f"BatchFeature(data={s})" if obj.__class__.__name__ == "BatchEncoding": @@ -400,25 +559,31 @@ def string_type( with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) + if verbose: + print(f"[string_type] TT3:{type(obj)}") return f"BatchEncoding(data={s})" if obj.__class__.__name__ == "VirtualTensor": + if verbose: + print(f"[string_type] TT4:{type(obj)}") return ( f"{obj.__class__.__name__}(name={obj.name!r}, " f"dtype={obj.dtype}, shape={obj.shape})" ) - if obj.__class__.__name__ in ("_DimHint", "_DimHintType"): - return str(obj) - if isinstance(obj, torch.nn.Module): + if verbose: + print(f"[string_type] MM:{type(obj)}") return f"{obj.__class__.__name__}(...)" if isinstance(obj, (torch.device, torch.dtype, torch.memory_format, torch.layout)): + if verbose: + print(f"[string_type] TT7:{type(obj)}") return f"{obj.__class__.__name__}({obj})" - if isinstance( + if isinstance( # TreeSpec, MappingKey, SequenceKey obj, ( torch.utils._pytree.TreeSpec, @@ -426,17 +591,20 @@ def string_type( torch.utils._pytree.SequenceKey, ), ): + if verbose: + print(f"[string_type] TT8:{type(obj)}") return repr(obj).replace(" ", "").replace("\n", " ") # to avoid failures - if type(obj).__name__ == "MambaCache": + if obj.__class__.__name__ == "MambaCache": c = string_type( obj.conv_states, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) d = string_type( obj.ssm_states, @@ -444,7 +612,10 @@ def string_type( with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) + if verbose: + print(f"[string_type] CACHE1:{type(obj)}") return f"MambaCache(conv_states={c}, ssm_states={d})" if obj.__class__.__name__ == "DynamicCache": @@ -454,6 +625,7 @@ def string_type( 
with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) vc = string_type( obj.value_cache, @@ -461,7 +633,10 @@ def string_type( with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) + if verbose: + print(f"[string_type] CACHE2:{type(obj)}") return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})" if obj.__class__.__name__ == "EncoderDecoderCache": @@ -471,6 +646,7 @@ def string_type( with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) cross = string_type( obj.cross_attention_cache, @@ -478,15 +654,22 @@ def string_type( with_min_max=with_min_max, with_device=with_device, limit=limit, + verbose=verbose, ) + if verbose: + print(f"[string_type] CACHE3:{type(obj)}") return ( f"{obj.__class__.__name__}(self_attention_cache={att}, " f"cross_attention_cache={cross})" ) if ignore: + if verbose: + print(f"[string_type] CACHE4:{type(obj)}") return f"{obj.__class__.__name__}(...)" + if verbose: + print(f"[string_type] END:{type(obj)}") raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}") diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 60d7cfd0..111d9169 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -52,19 +52,19 @@ def get_inputs( :return: dictionary """ batch = torch.export.Dim("batch", min=1, max=1024) - seq_length = torch.export.Dim("seq_length", min=1, max=4096) - cache_length = torch.export.Dim("cache_length", min=1, max=4096) - images = torch.export.Dim("images", min=1, max=4096) + seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) + cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) + images = "images" # torch.export.Dim("images", min=1, max=4096) shapes = { "input_ids": {0: batch, 1: seq_length}, "attention_mask": { 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + 1: "cache+seq", # cache_length + seq_length }, "position_ids": { 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + 1: "cache+seq", # cache_length + seq_length }, "past_key_values": [ [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py index abce3714..3dea707c 100644 --- a/onnx_diagnostic/tasks/text2text_generation.py +++ b/onnx_diagnostic/tasks/text2text_generation.py @@ -59,14 +59,14 @@ def get_inputs( encoder_outputs:dict(last_hidden_state:T1s1x16x512) """ batch = torch.export.Dim("batch", min=1, max=1024) - seq_length = torch.export.Dim("seq_length", min=1, max=4096) - cache_length = torch.export.Dim("cache_length", min=1, max=4096) - cache_length2 = torch.export.Dim("cache_length2", min=1, max=4096) + seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) + cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096) + cache_length2 = "cache_length_val" # torch.export.Dim("cache_length2", min=1, max=4096) shapes = { "input_ids": {0: batch, 1: seq_length}, - "decoder_input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC}, - "attention_mask": {0: batch, 1: torch.export.Dim.DYNAMIC}, + "decoder_input_ids": {0: batch, 1: "seq_ids"}, + "attention_mask": {0: batch, 1: "seq_mask"}, # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC}, "past_key_values": [ [ diff --git a/onnx_diagnostic/tasks/text_classification.py 
b/onnx_diagnostic/tasks/text_classification.py index 37d262f0..810274cc 100644 --- a/onnx_diagnostic/tasks/text_classification.py +++ b/onnx_diagnostic/tasks/text_classification.py @@ -35,7 +35,7 @@ def get_inputs( attention_mask:T7s1x13[1,1:A1.0]) """ batch = torch.export.Dim("batch", min=1, max=1024) - seq_length = torch.export.Dim("sequence_length", min=1, max=1024) + seq_length = "seq_length" # torch.export.Dim("sequence_length", min=1, max=1024) shapes = { "input_ids": {0: batch, 1: seq_length}, "token_type_ids": {0: batch, 1: seq_length}, diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index eb81bc78..7ce039ed 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -84,8 +84,8 @@ def get_inputs( :return: dictionary """ batch = torch.export.Dim("batch", min=1, max=1024) - seq_length = torch.export.Dim("seq_length", min=1, max=4096) - cache_length = torch.export.Dim("cache_length", min=1, max=4096) + seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) + cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) if config is not None and config.__class__.__name__ == "FalconMambaConfig": seq_length_multiple = 8 @@ -101,11 +101,11 @@ def get_inputs( "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC}, "attention_mask": { 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + 1: "cache+seq", # cache_length + seq_length }, "cache_position": { 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + 1: "cache+seq", # cache_length + seq_length }, "cache_params": [ [{0: batch} for _ in range(num_hidden_layers)], @@ -145,11 +145,11 @@ def get_inputs( "input_ids": {0: batch, 1: seq_length}, "attention_mask": { 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + 1: "cache+seq", # cache_length + seq_length }, "position_ids": { 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + 1: "cache+seq", # cache_length + seq_length }, "past_key_values": [ [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], diff --git a/onnx_diagnostic/torch_export_patches/__init__.py b/onnx_diagnostic/torch_export_patches/__init__.py index a67f064d..ff978ee3 100644 --- a/onnx_diagnostic/torch_export_patches/__init__.py +++ b/onnx_diagnostic/torch_export_patches/__init__.py @@ -2,110 +2,3 @@ bypass_export_some_errors, register_additional_serialization_functions, ) - -""" --- Missing dependencies -- - -def is_torchdynamo_exporting() -> bool: - "Tells if torch is exporting a model." - import torch - - if not hasattr(torch.compiler, "is_exporting"): - # torch.compiler.is_exporting requires torch>=2.7 - return False - - try: - return torch.compiler.is_exporting() - except Exception: - try: - import torch._dynamo as dynamo - - return dynamo.is_exporting() # type: ignore - except Exception: - return False - - -def string_type(anything, **args): - # too long - # from onnx_diagnostic.helpers import string_type - return str(anything) - - -if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): - - def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> transformers.cache_utils.DynamicCache: - ''' - Creates an instance of :class:`transformers.cache_utils.DynamicCache`. - This version is valid for ``transformers >= 4.50``. 
- - :param key_value_pairs: list of pairs of (key, values) - :return: :class:`transformers.cache_utils.DynamicCache` - - Example: - - :: - - n_layers = 2 - bsize, nheads, slen, dim = 2, 4, 3, 7 - - past_key_values = make_dynamic_cache( - [ - ( - torch.randn(bsize, nheads, slen, dim), - torch.randn(bsize, nheads, slen, dim), - ) - for i in range(n_layers) - ] - ) - print(string_type(past_key_values, with_shape=True)) - ''' - return transformers.cache_utils.DynamicCache(key_value_pairs) - -else: - - def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> transformers.cache_utils.DynamicCache: - ''' - Creates an instance of :class:`transformers.cache_utils.DynamicCache`. - This version is valid for ``transformers < 4.50``. - - :param key_value_pairs: list of pairs of (key, values) - :return: :class:`transformers.cache_utils.DynamicCache` - - Example: - - :: - - n_layers = 2 - bsize, nheads, slen, dim = 2, 4, 3, 7 - - past_key_values = make_dynamic_cache( - [ - ( - torch.randn(bsize, nheads, slen, dim), - torch.randn(bsize, nheads, slen, dim), - ) - for i in range(n_layers) - ] - ) - print(string_type(past_key_values, with_shape=True)) - ''' - cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) - for i, (key, value) in enumerate(key_value_pairs): - cache.update(key, value, i) - return cache - - -def make_encoder_decoder_cache( - self_attention_cache: transformers.cache_utils.DynamicCache, - cross_attention_cache: transformers.cache_utils.DynamicCache, -) -> transformers.cache_utils.EncoderDecoderCache: - "Creates an EncoderDecoderCache." - return transformers.cache_utils.EncoderDecoderCache( - self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache - ) - -""" diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index d42d2534..dcb52b7d 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -193,9 +193,8 @@ def flatten_mamba_cache( ) -> Tuple[List[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" flat = [ - (k, getattr(mamba_cache, k)) - for k in ["conv_states", "ssm_states"] - if hasattr(mamba_cache, k) + ("conv_states", mamba_cache.conv_states), + ("ssm_states", mamba_cache.ssm_states), ] return [f[1] for f in flat], [f[0] for f in flat] @@ -251,11 +250,7 @@ def flatten_dynamic_cache( """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"): return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache) - flat = [ - (k, getattr(dynamic_cache, k)) - for k in ["key_cache", "value_cache"] - if hasattr(dynamic_cache, k) - ] + flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)] return [f[1] for f in flat], [f[0] for f in flat] diff --git a/onnx_diagnostic/torch_export_patches/patch_inputs.py b/onnx_diagnostic/torch_export_patches/patch_inputs.py index 8cba5ac4..2c46a0d3 100644 --- a/onnx_diagnostic/torch_export_patches/patch_inputs.py +++ b/onnx_diagnostic/torch_export_patches/patch_inputs.py @@ -119,6 +119,10 @@ def convert_dynamic_axes_into_dynamic_shapes( ) changes[k] = type(updated_kwargs[k]) continue + if isinstance(v, transformers.cache_utils.DynamicCache): + updated_kwargs[k] = [v.key_cache, v.value_cache] + 
changes[k] = type(v) + continue raise NotImplementedError( f"Unexpected type {type(v)} for parameter {k!r} " f"({string_type(v, with_shape=True)})" @@ -132,6 +136,13 @@ if k not in changes and k in updated_kwargs and isinstance(v, dict): dynamic_shapes[k] = v continue + if ( + k in updated_kwargs + and k in changes + and changes[k] == transformers.cache_utils.DynamicCache + ): + dynamic_shapes[k] = v + continue if "." in k: # something like present.0.key prefix = k.split(".")[0] @@ -172,3 +183,21 @@ ) return (), updated_kwargs, dynamic_shapes + + +def use_dyn_not_str(dynamic_shapes: Any) -> Any: + """ + Some functions return dynamic shapes as strings. + This function replaces them with ``torch.export.Dim.DYNAMIC``. + """ + if isinstance(dynamic_shapes, list): + return [use_dyn_not_str(a) for a in dynamic_shapes] + if isinstance(dynamic_shapes, tuple): + return tuple(use_dyn_not_str(a) for a in dynamic_shapes) + if isinstance(dynamic_shapes, dict): + return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()} + if isinstance(dynamic_shapes, set): + return {use_dyn_not_str(a) for a in dynamic_shapes} + if isinstance(dynamic_shapes, str): + return torch.export.Dim.DYNAMIC + return dynamic_shapes diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 9fcdc559..36af6d0b 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -10,6 +10,7 @@ from ..helpers.helper import flatten_object from ..helpers.rt_helper import make_feeds from ..helpers.torch_test_helper import to_any, torch_deepcopy +from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes from ..torch_export_patches import bypass_export_some_errors from .hghub import get_untrained_model_with_inputs from .hghub.model_inputs import random_input_kwargs @@ -24,10 +25,6 @@ def empty(value: Any) -> bool: return False -def _ds_clean(v): - return string_type(v) - - def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, Any]: """ Returns dummy inputs for a specific task.
@@ -287,7 +284,7 @@ def validate_model( print(f"[validate_model] current inputs: {string_type(data['inputs'])}") print( f"[validate_model] current dynnamic_shapes: " - f"{_ds_clean(data['dynamic_shapes'])}" + f"{string_type(data['dynamic_shapes'])}" ) data["inputs"], data["dynamic_shapes"] = filter_inputs( data["inputs"], @@ -297,7 +294,7 @@ def validate_model( ) if verbose: print(f"[validate_model] new inputs: {string_type(data['inputs'])}") - print(f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}") + print(f"[validate_model] new dynamic_hapes: {string_type(data['dynamic_shapes'])}") if not empty(dtype): if isinstance(dtype, str): @@ -319,7 +316,7 @@ def validate_model( for k in ["task", "size", "n_weights"]: summary[f"model_{k.replace('_','')}"] = data[k] summary["model_inputs"] = string_type(data["inputs"], with_shape=True) - summary["model_shapes"] = _ds_clean(str(data["dynamic_shapes"])) + summary["model_shapes"] = string_type(str(data["dynamic_shapes"])) summary["model_class"] = data["model"].__class__.__name__ summary["model_config_class"] = data["configuration"].__class__.__name__ summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") @@ -333,7 +330,7 @@ def validate_model( for k, v in data["inputs"].items(): print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}") for k, v in data["dynamic_shapes"].items(): - print(f"[validate_model] +SHAPE {k}={_ds_clean(v)}") + print(f"[validate_model] +SHAPE {k}={string_type(v)}") print("[validate_model] --") if do_run: @@ -646,6 +643,18 @@ def call_torch_export_export( strict = "nostrict" not in exporter args, kwargs = split_args_kwargs(data["inputs_export"]) ds = data.get("dynamic_shapes", None) + + summary["export_exporter"] = exporter + summary["export_optimization"] = optimization or "" + summary["export_strict"] = strict + summary["export_args"] = string_type(args, with_shape=True) + summary["export_kwargs"] = string_type(kwargs, with_shape=True) + summary["export_dynamic_shapes"] = string_type(ds) + + # There is an issue with DynamicShape [[],[]] becomes [] + dse = CoupleInputsDynamicShapes(args, kwargs, ds).replace_string_by() + summary["export_dynamic_shapes_export_export"] = string_type(dse) + if verbose: print( f"[call_torch_export_export] exporter={exporter!r}, " @@ -653,19 +662,15 @@ def call_torch_export_export( ) print(f"[call_torch_export_export] args={string_type(args, with_shape=True)}") print(f"[call_torch_export_export] kwargs={string_type(kwargs, with_shape=True)}") - print(f"[call_torch_export_export] dynamic_shapes={_ds_clean(ds)}") + print(f"[call_torch_export_export] dynamic_shapes={string_type(ds)}") + print(f"[call_torch_export_export] dynamic_shapes_export_export={string_type(dse)}") print("[call_torch_export_export] export...") - summary["export_exporter"] = exporter - summary["export_optimization"] = optimization or "" - summary["export_strict"] = strict - summary["export_args"] = string_type(args, with_shape=True) - summary["export_kwargs"] = string_type(kwargs, with_shape=True) begin = time.perf_counter() if quiet: try: ep = torch.export.export( - data["model"], args, kwargs=kwargs, dynamic_shapes=ds, strict=strict + data["model"], args, kwargs=kwargs, dynamic_shapes=dse, strict=strict ) except Exception as e: summary["ERR_export_export"] = str(e) @@ -674,7 +679,7 @@ def call_torch_export_export( return summary, data else: ep = torch.export.export( - data["model"], args, kwargs=kwargs, dynamic_shapes=ds, strict=strict + data["model"], args, 
kwargs=kwargs, dynamic_shapes=dse, strict=strict ) summary["time_export_export"] = time.perf_counter() - begin @@ -887,7 +892,7 @@ def call_torch_export_onnx( ) print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}") print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}") - print(f"[call_torch_export_onnx] dynamic_shapes={_ds_clean(ds)}") + print(f"[call_torch_export_onnx] dynamic_shapes={string_type(ds)}") print("[call_torch_export_onnx] export...") summary["export_exporter"] = exporter summary["export_optimization"] = optimization or "" @@ -895,15 +900,30 @@ def call_torch_export_onnx( summary["export_args"] = string_type(args, with_shape=True) summary["export_kwargs"] = string_type(kwargs, with_shape=True) - export_export_kwargs = ( - dict(dynamo=True, dynamic_shapes=ds) - if dynamo - else dict( + if dynamo: + export_export_kwargs = dict(dynamo=True, dynamic_shapes=ds) + else: + export_export_kwargs = dict( dynamo=False, - dynamic_axes=CoupleInputsDynamicShapes(args, kwargs, ds).replace_by_string(), + dynamic_axes={ + k: v + for k, v in CoupleInputsDynamicShapes(args, kwargs, ds) + .replace_by_string() + .items() + if isinstance(v, dict) + }, + ) + args = tuple(flatten_unflatten_for_dynamic_shapes(a) for a in args) + kwargs = {k: flatten_unflatten_for_dynamic_shapes(v) for k, v in kwargs.items()} + if verbose: + print("[call_torch_export_onnx] dynamo=False so...") + print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}") + print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}") + if verbose: + print( + f"[call_torch_export_onnx] export_export_kwargs=" + f"{string_type(export_export_kwargs, with_shape=True)}" ) - ) - begin = time.perf_counter() if quiet: try: @@ -1005,7 +1025,7 @@ def call_torch_export_custom( ) print(f"[call_torch_export_custom] args={string_type(args, with_shape=True)}") print(f"[call_torch_export_custom] kwargs={string_type(kwargs, with_shape=True)}") - print(f"[call_torch_export_custom] dynamic_shapes={_ds_clean(ds)}") + print(f"[call_torch_export_custom] dynamic_shapes={string_type(ds)}") print("[call_torch_export_custom] export...") summary["export_exporter"] = exporter summary["export_optimization"] = optimization or ""
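
Note on the headline change (:pr:`58`): the task files above now return plain strings (``"seq_length"``, ``"cache+seq"``, ...) as dynamic dimension names, which ``torch.onnx.export`` accepts and can reuse in the ONNX graph, while ``torch.export.export`` still expects ``Dim`` objects or ``Dim.DYNAMIC``. A minimal sketch of how ``use_dyn_not_str`` bridges the two; the helper and its import path come from this diff, the shape names and the commented export call are illustrative placeholders::

    import torch
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    batch = torch.export.Dim("batch", min=1, max=1024)
    # Dynamic shapes as the task helpers now return them: strings name the dimensions.
    dynamic_shapes = {
        "input_ids": {0: batch, 1: "seq_length"},
        "attention_mask": {0: batch, 1: "cache+seq"},
        "past_key_values": [
            [{0: batch, 2: "cache_length"}],  # key_cache of a one-layer cache
            [{0: batch, 2: "cache_length"}],  # value_cache of a one-layer cache
        ],
    }

    # torch.export.export does not accept strings, so every string is replaced
    # by torch.export.Dim.DYNAMIC before calling it (Dim objects are kept as is).
    ds_for_export = use_dyn_not_str(dynamic_shapes)
    assert ds_for_export["input_ids"] == {0: batch, 1: torch.export.Dim.DYNAMIC}
    # model and inputs below are placeholders for an actual model and its kwargs:
    # ep = torch.export.export(model, (), kwargs=inputs,
    #                          dynamic_shapes=ds_for_export, strict=False)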
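
The ``flatten_unflatten=True`` path added to ``CoupleInputsDynamicShapes._generic_walker`` keeps the nested ``[[...], [...]]`` structure of a ``DynamicCache`` when strings are replaced, which is what the updated expectations in ``test_dynamic_shapes.py`` check. A short usage sketch mirroring the new ``test_replace_by`` test (tensor sizes are arbitrary)::

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.export import CoupleInputsDynamicShapes
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

    # One-layer DynamicCache plus the usual LLM inputs.
    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
    kwargs = dict(
        input_ids=torch.zeros(2, 3),
        attention_mask=torch.zeros(2, 3),
        past_key_values=cache,
    )
    batch = torch.export.Dim("batch")
    dynamic_shapes = dict(
        input_ids={0: batch, 1: "seq"},
        attention_mask={0: batch, 1: "seq"},
        past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
    )

    # The patching context registers the cache serialization functions the walker needs.
    with bypass_export_some_errors(patch_transformers=True):
        cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes)
        # Strings become torch.export.Dim.DYNAMIC while the nested
        # [[key_cache], [value_cache]] structure is preserved.
        print(cpl.replace_string_by()["past_key_values"])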