diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 02dfc4c0..f7c2bd8c 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,11 @@ Change Logs =========== +0.7.1 ++++++ + +* :pr:`152`: add a function to compute fully dynamic shapes given any inputs + 0.7.0 +++++ diff --git a/_doc/api/export/index.rst b/_doc/api/export/index.rst index a46e7742..eeba2d35 100644 --- a/_doc/api/export/index.rst +++ b/_doc/api/export/index.rst @@ -6,6 +6,7 @@ onnx_diagnostic.export :caption: modules dynamic_shapes + shape_helper validate CoupleInputsDynamicShapes diff --git a/_doc/api/export/shape_helper.rst b/_doc/api/export/shape_helper.rst new file mode 100644 index 00000000..78ccb5cd --- /dev/null +++ b/_doc/api/export/shape_helper.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.export.shape_helper +=================================== + +.. automodule:: onnx_diagnostic.export.shape_helper + :members: + :no-undoc-members: diff --git a/_doc/index.rst b/_doc/index.rst index 11f3a9fc..a5e3b526 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -211,8 +211,9 @@ The function replaces dynamic dimensions defined as strings by ``torch.export.Dim.DYNAMIC``. Older versions -++++++++++++++ +============== +* `0.7.1 <../v0.7.1/index.html>`_ * `0.7.0 <../v0.7.0/index.html>`_ * `0.6.3 <../v0.6.3/index.html>`_ * `0.5.0 <../v0.5.0/index.html>`_ diff --git a/_doc/recipes/plot_dynamic_shapes_what.py b/_doc/recipes/plot_dynamic_shapes_what.py new file mode 100644 index 00000000..d0c5eefd --- /dev/null +++ b/_doc/recipes/plot_dynamic_shapes_what.py @@ -0,0 +1,78 @@ +""" +Builds dynamic shapes from any input +==================================== + +Getting dynamic shapes right for :func:`torch.export.export` when the inputs +includes a custom class such as a :class:`transformers.cache_utils.DynamicCache`. +:func:`torch.export.export` cannot use a DynamicCache filled with dynamic shapes +but instead it uses a kind of unserialized serialized form of it. + +Standard inputs for a LLM with a dynamic cache +++++++++++++++++++++++++++++++++++++++++++++++ +""" + +import pprint +import torch +from onnx_diagnostic import doc +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache +from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs +from onnx_diagnostic.torch_export_patches import torch_export_patches + +bsize, nheads, slen, dim = 2, 1, 30, 96 + +inputs = dict( + input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64), + attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64), + position_ids=torch.arange(3, dtype=torch.int64), + past_key_values=make_dynamic_cache( + [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))] + ), +) + +print(string_type(inputs, with_shape=True)) + +# %% +# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs` +# produces the corresponding dynamic shapes assuming they are all dynamic. +ds = all_dynamic_shape_from_inputs(inputs) +pprint.pprint(ds) + +# %% +# What about a StaticCache? +# +++++++++++++++++++++++++ +# +# We use :func:`onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs` to get +# a consistent configuration with a static cache. + +data = get_untrained_model_with_inputs( + "arnir0/Tiny-LLM", + model_kwargs=dict(cache_implementation="static"), + inputs_kwargs=dict(cls_cache="StaticCache"), +) +inputs = data["inputs"] +print(string_type(inputs, with_shape=True)) + +# %% +# And the input shapes. +ds = all_dynamic_shape_from_inputs(inputs) +if ds["past_key_values"]: + print("transformers implemented serialization function for StaticCache.") +else: + print("We need to use serialization function implemented in this package.") + with torch_export_patches(patch_transformers=True): + ds = all_dynamic_shape_from_inputs(inputs) + +# %% +# That gives. +pprint.pprint(ds) + +# %% +# We can compare with the ones returned by the function. +pprint.pprint(data["dynamic_shapes"]) + + +# %% + +doc.plot_legend("dynamic shapes\nfrom inputs", "dynamic shapes", "green") diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py new file mode 100644 index 00000000..69c87399 --- /dev/null +++ b/_unittests/ut_export/test_shape_helper.py @@ -0,0 +1,46 @@ +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch +from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs + + +class TestShapeHelper(ExtTestCase): + @requires_transformers("4.52") + @requires_torch("2.7.99") + def test_all_dynamic_shape_from_inputs(self): + ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6)))) + self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds) + ds = all_dynamic_shape_from_inputs( + (torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO + ) + self.assertEqual( + [ + {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO}, + {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO}, + ], + ds, + ) + + @requires_transformers("4.52") + @requires_torch("2.7.99") + def test_all_dynamic_shape_from_inputs_dynamic_cache(self): + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") + print(self.string_type(data["inputs"], with_shape=True)) + ds = all_dynamic_shape_from_inputs(data["inputs"]) + self.assertEqual( + { + "input_ids": {0: "d_0_0", 1: "d_0_1"}, + "attention_mask": {0: "d_1_0", 1: "d_1_1"}, + "position_ids": {0: "d_2_0", 1: "d_2_1"}, + "past_key_values": { + "key_cache": [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}], + "value_cache": [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}], + }, + }, + ds, + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index 98abae93..bd37424e 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.7.0" +__version__ = "0.7.1" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/export/shape_helper.py b/onnx_diagnostic/export/shape_helper.py new file mode 100644 index 00000000..98c91025 --- /dev/null +++ b/onnx_diagnostic/export/shape_helper.py @@ -0,0 +1,49 @@ +from typing import Any, Set +from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes + + +def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any: + """ + Returns the dynamic shapes for the given inputs. + All dimensions are considered as dynamic. + ``dim_prefix`` can be a string (the function uses it as a prefix), + or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``. + + .. runpython:: + :showcode: + + import pprint + import torch + from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache + from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs + + bsize, nheads, slen, dim = 2, 1, 30, 96 + inputs = dict( + input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64), + attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64), + position_ids=torch.arange(3, dtype=torch.int64), + past_key_values=make_dynamic_cache( + [(torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim))] + ), + ) + ds = all_dynamic_shape_from_inputs(inputs) + pprint.pprint(ds) + """ + if isinstance(dim_prefix, str): + prefixes: Set[str] = set() + + def tensor_to_shape(tensor): + n = len(prefixes) + p = f"{dim_prefix}_{n}" + prefixes.add(p) + return {i: f"{p}_{i}" for i in range(tensor.ndim)} + + else: + + def tensor_to_shape(tensor): + return {i: dim_prefix for i in range(tensor.ndim)} # noqa: C420 + + return flatten_unflatten_for_dynamic_shapes( + inputs, change_function=tensor_to_shape, use_dict=True + ) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index e37a1e2d..9b1ba31c 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -1,11 +1,15 @@ -from typing import Any, List, Tuple +from typing import Any, Callable, List, Optional, Tuple import packaging.version as pv import torch import transformers import transformers.cache_utils -def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any: +def flatten_unflatten_for_dynamic_shapes( + obj: Any, + use_dict: bool = False, + change_function: Optional[Callable[[torch.Tensor], Any]] = None, +) -> Any: """ Returns the object in a different structure similar to what the definition of the dynamic shapes should use. @@ -16,10 +20,12 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An the context gives the dictionary keys but it is not expressed in the dynamic shapes, these specifications seems to be different for the strict and non strict mode. + :param change_function: to modifies the tensor in the structure itself, + like replace them by a shape :return: the serialized object """ if isinstance(obj, torch.Tensor): - return obj + return change_function(obj) if change_function else obj flat, spec = torch.utils._pytree.tree_flatten(obj) start = 0 end = 0 @@ -27,7 +33,9 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An for subspec in spec.children_specs: end += subspec.num_leaves value = subspec.unflatten(flat[start:end]) - value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict) + value = flatten_unflatten_for_dynamic_shapes( + value, use_dict=use_dict, change_function=change_function + ) subtrees.append(value) start = end if use_dict and (spec.type is dict or spec.context):