5 changes: 5 additions & 0 deletions CHANGELOGS.rst
@@ -1,6 +1,11 @@
Change Logs
===========

0.7.1
+++++

* :pr:`152`: add a function to compute fully dynamic shapes given any inputs

0.7.0
+++++

1 change: 1 addition & 0 deletions _doc/api/export/index.rst
@@ -6,6 +6,7 @@ onnx_diagnostic.export
:caption: modules

dynamic_shapes
shape_helper
validate

CoupleInputsDynamicShapes
7 changes: 7 additions & 0 deletions _doc/api/export/shape_helper.rst
@@ -0,0 +1,7 @@

onnx_diagnostic.export.shape_helper
===================================

.. automodule:: onnx_diagnostic.export.shape_helper
:members:
:no-undoc-members:
3 changes: 2 additions & 1 deletion _doc/index.rst
@@ -211,8 +211,9 @@ The function replaces dynamic dimensions defined as strings by
``torch.export.Dim.DYNAMIC``.

Older versions
++++++++++++++
==============

* `0.7.1 <../v0.7.1/index.html>`_
* `0.7.0 <../v0.7.0/index.html>`_
* `0.6.3 <../v0.6.3/index.html>`_
* `0.5.0 <../v0.5.0/index.html>`_
78 changes: 78 additions & 0 deletions _doc/recipes/plot_dynamic_shapes_what.py
@@ -0,0 +1,78 @@
"""
Builds dynamic shapes from any input
====================================

Getting dynamic shapes right for :func:`torch.export.export` is tricky when the
inputs include a custom class such as a :class:`transformers.cache_utils.DynamicCache`.
:func:`torch.export.export` cannot consume dynamic shapes expressed on a DynamicCache
directly; they must follow its flattened (serialized) structure instead.

Standard inputs for a LLM with a dynamic cache
++++++++++++++++++++++++++++++++++++++++++++++
"""

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches

bsize, nheads, slen, dim = 2, 1, 30, 96

inputs = dict(
input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
position_ids=torch.arange(3, dtype=torch.int64),
past_key_values=make_dynamic_cache(
[(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
),
)

print(string_type(inputs, with_shape=True))

# %%
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
# produces the corresponding dynamic shapes, assuming every dimension is dynamic.
ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)

# %%
# What about a StaticCache?
# +++++++++++++++++++++++++
#
# We use :func:`onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs` to get
# a consistent configuration with a static cache.

data = get_untrained_model_with_inputs(
"arnir0/Tiny-LLM",
model_kwargs=dict(cache_implementation="static"),
inputs_kwargs=dict(cls_cache="StaticCache"),
)
inputs = data["inputs"]
print(string_type(inputs, with_shape=True))

# %%
# And the corresponding dynamic shapes.
ds = all_dynamic_shape_from_inputs(inputs)
if ds["past_key_values"]:
    print("transformers implements a serialization function for StaticCache.")
else:
    print("We need to use the serialization functions implemented in this package.")
with torch_export_patches(patch_transformers=True):
ds = all_dynamic_shape_from_inputs(inputs)

# %%
# That gives:
pprint.pprint(ds)

# %%
# We can compare them with the dynamic shapes returned by
# :func:`onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs`.
pprint.pprint(data["dynamic_shapes"])


# %%

doc.plot_legend("dynamic shapes\nfrom inputs", "dynamic shapes", "green")
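As a side note (not part of the PR): the fully dynamic specification produced this way can be passed straight to :func:`torch.export.export`. A minimal sketch; ``TinyModel`` is a hypothetical stand-in for any module with tensor-only inputs.

import torch

from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs


class TinyModel(torch.nn.Module):
    "Hypothetical model, for illustration only."

    def forward(self, x, y):
        return x + y


inputs = (torch.randn(4, 6), torch.randn(4, 6))
# Dim.DYNAMIC sidesteps dimension naming: every axis is simply marked dynamic.
ds = all_dynamic_shape_from_inputs(inputs, dim_prefix=torch.export.Dim.DYNAMIC)
ep = torch.export.export(TinyModel(), inputs, dynamic_shapes=ds)
print(ep)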
46 changes: 46 additions & 0 deletions _unittests/ut_export/test_shape_helper.py
@@ -0,0 +1,46 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs


class TestShapeHelper(ExtTestCase):
@requires_transformers("4.52")
@requires_torch("2.7.99")
def test_all_dynamic_shape_from_inputs(self):
ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
ds = all_dynamic_shape_from_inputs(
(torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
)
self.assertEqual(
[
{0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
{0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
],
ds,
)

@requires_transformers("4.52")
@requires_torch("2.7.99")
def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
print(self.string_type(data["inputs"], with_shape=True))
ds = all_dynamic_shape_from_inputs(data["inputs"])
self.assertEqual(
{
"input_ids": {0: "d_0_0", 1: "d_0_1"},
"attention_mask": {0: "d_1_0", 1: "d_1_1"},
"position_ids": {0: "d_2_0", 1: "d_2_1"},
"past_key_values": {
"key_cache": [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
"value_cache": [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
},
},
ds,
)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@
Functions and classes to dig into a model, whether it is right, slow, or wrong...
"""

__version__ = "0.7.0"
__version__ = "0.7.1"
__author__ = "Xavier Dupré"
49 changes: 49 additions & 0 deletions onnx_diagnostic/export/shape_helper.py
@@ -0,0 +1,49 @@
from typing import Any, Set
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes


def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
"""
Returns the dynamic shapes for the given inputs.
    All dimensions are considered dynamic.
    ``dim_prefix`` can be a string (used as a prefix for the generated
    dimension names), or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.

.. runpython::
:showcode:

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

bsize, nheads, slen, dim = 2, 1, 30, 96
inputs = dict(
input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
position_ids=torch.arange(3, dtype=torch.int64),
past_key_values=make_dynamic_cache(
[(torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim))]
),
)
ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)
"""
if isinstance(dim_prefix, str):
prefixes: Set[str] = set()

def tensor_to_shape(tensor):
n = len(prefixes)
p = f"{dim_prefix}_{n}"
prefixes.add(p)
return {i: f"{p}_{i}" for i in range(tensor.ndim)}

else:

def tensor_to_shape(tensor):
return {i: dim_prefix for i in range(tensor.ndim)} # noqa: C420

return flatten_unflatten_for_dynamic_shapes(
inputs, change_function=tensor_to_shape, use_dict=True
)
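A fully dynamic specification is usually a starting point. A short sketch of tightening it afterwards with plain Python (the ``share_batch_dim`` helper below is hypothetical, not provided by this PR), renaming axis 0 of every tensor to a shared ``batch`` dimension:

import torch

from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs


def share_batch_dim(ds):
    "Recursively renames axis 0 of every per-tensor spec to a common 'batch'."
    if isinstance(ds, dict) and all(isinstance(k, int) for k in ds):
        # a per-tensor spec such as {0: "d_0_0", 1: "d_0_1"}
        return {i: ("batch" if i == 0 else v) for i, v in ds.items()}
    if isinstance(ds, dict):
        return {k: share_batch_dim(v) for k, v in ds.items()}
    if isinstance(ds, (list, tuple)):
        return type(ds)(share_batch_dim(v) for v in ds)
    return ds


ds = all_dynamic_shape_from_inputs((torch.randn(2, 3), torch.randn(2, 5)))
print(share_batch_dim(ds))
# [{0: 'batch', 1: 'd_0_1'}, {0: 'batch', 1: 'd_1_1'}]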
16 changes: 12 additions & 4 deletions onnx_diagnostic/helpers/cache_helper.py
@@ -1,11 +1,15 @@
from typing import Any, List, Tuple
from typing import Any, Callable, List, Optional, Tuple
import packaging.version as pv
import torch
import transformers
import transformers.cache_utils


def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
def flatten_unflatten_for_dynamic_shapes(
obj: Any,
use_dict: bool = False,
change_function: Optional[Callable[[torch.Tensor], Any]] = None,
) -> Any:
"""
Returns the object in a different structure similar to what
the definition of the dynamic shapes should use.
@@ -16,18 +20,22 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
the context gives the dictionary keys but it is not expressed
in the dynamic shapes; these specifications seem to differ
between strict and non-strict mode.
:param change_function: function applied to every tensor in the
    structure, for instance to replace it by its shape
:return: the serialized object
"""
if isinstance(obj, torch.Tensor):
return obj
return change_function(obj) if change_function else obj
flat, spec = torch.utils._pytree.tree_flatten(obj)
start = 0
end = 0
subtrees = []
for subspec in spec.children_specs:
end += subspec.num_leaves
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
value = flatten_unflatten_for_dynamic_shapes(
value, use_dict=use_dict, change_function=change_function
)
subtrees.append(value)
start = end
if use_dict and (spec.type is dict or spec.context):
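The new ``change_function`` parameter is not limited to shapes. A small sketch, assuming only the signature added above: replacing every tensor in a nested structure by its dtype while keeping the structure intact.

import torch

from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes

nested = {"a": torch.randn(2, 3), "b": [torch.zeros(4, dtype=torch.int64)]}
# change_function is called once per tensor; containers are rebuilt around the results.
dtypes = flatten_unflatten_for_dynamic_shapes(
    nested, use_dict=True, change_function=lambda t: t.dtype
)
print(dtypes)
# {'a': torch.float32, 'b': [torch.int64]}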