4 changes: 4 additions & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,10 @@ Change Logs
0.4.0
+++++

* :pr:`58`: add function ``use_dyn_not_str`` to replace strings with
  ``torch.export.Dim.DYNAMIC``; strings are used instead of
  ``torch.export.Dim.DYNAMIC`` when returning the dynamic shapes for a
  specific model, since the string form is a valid definition for
  ``torch.onnx.export``, which can reuse the names
* :pr:`55`: add support for text-classification
* :pr:`54`: add support for fill-mask, refactoring
* :pr:`52`: add support for zero-shot-image-classification
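A minimal sketch of the new helper, based on the unit tests added in this PR (``test_use_dyn_not_str``); the output comment is illustrative:

    import torch
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    batch = torch.export.Dim("batch")
    dynamic_shapes = dict(
        input_ids={0: batch, 1: "seq"},
        attention_mask={0: batch, 1: "seq"},
    )
    # Every string axis ("seq") becomes torch.export.Dim.DYNAMIC, the form
    # torch.export.export expects; torch.onnx.export accepts the string form
    # and can reuse the names.
    print(use_dyn_not_str(dynamic_shapes))
    # both entries become {0: batch, 1: DYNAMIC}; the named Dim is untouched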
3 changes: 2 additions & 1 deletion _doc/examples/plot_export_hub_codellama.py
@@ -30,6 +30,7 @@
task_from_id,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

model_id = "codellama/CodeLlama-7b-Python-hf"
print("info", get_model_info(model_id))
@@ -96,7 +97,7 @@
model,
(),
kwargs=f(data["inputs"]),
dynamic_shapes=data["dynamic_shapes"],
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
strict=False,
)
print(ep)
3 changes: 2 additions & 1 deletion _doc/examples/plot_export_tiny_phi2.py
@@ -27,6 +27,7 @@
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers.rt_helper import make_feeds
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.hghub import (
get_untrained_model_with_inputs,
)
@@ -92,7 +93,7 @@
untrained_model,
(),
kwargs=modificator(copy.deepcopy(inputs)),
dynamic_shapes=dynamic_shapes,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False, # mandatory for torch==2.6
)

1 change: 1 addition & 0 deletions _doc/index.rst
@@ -20,6 +20,7 @@ onnx-diagnostic: investigate onnx models

The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
it helps export **pytorch models into ONNX** and is mostly designed for LLMs using dynamic caches.
Sources available at `github/onnx-diagnostic <https://github.com/sdpython/onnx-diagnostic/>`_.

.. code-block:: python

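The code block above is truncated in this diff. For context, a sketch of the export pattern the updated examples in this PR follow, assembled from ``plot_export_hub_codellama.py`` and the new tests (the context-manager usage around ``torch.export.export`` is an assumption based on ``test_replace_by``):

    import torch
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    # model, inputs and dynamic_shapes are obtained as in the examples above.
    with bypass_export_some_errors(patch_transformers=True):
        ep = torch.export.export(
            model,
            (),
            kwargs=inputs,
            dynamic_shapes=use_dyn_not_str(dynamic_shapes),
            strict=False,
        )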
15 changes: 8 additions & 7 deletions _unittests/ut_export/test_dynamic_shapes.py
@@ -576,17 +576,20 @@ def test_couple_input_ds_cache(self):
Cls(
(),
kwargs,
{"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])},
{
"A": ds_batch,
"B": (ds_batch, [[ds_batch, ds_batch], [ds_batch, ds_batch]]),
},
).invalid_dimensions_for_export(),
)
self.assertEqual(
{"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])},
{"B": (None, [[None, {2: "d=[1]"}], [None, {2: "d=[1]"}]])},
Cls(
(),
kwargs,
{
"A": ds_batch,
"B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]),
"B": (ds_batch, [[ds_batch, ds_batch_seq], [ds_batch, ds_batch_seq]]),
},
).invalid_dimensions_for_export(),
)
@@ -762,10 +765,8 @@ def test_dynamic_cache_replace_by_string(self):
self.assertEqual(
{
"cache": [
{0: "Dim0", 1: "Dim1"},
{0: "Dim2", 1: "Dim3"},
{0: "Dim4", 1: "Dim5"},
{0: "Dim6", 1: "Dim7"},
[{0: "Dim0", 1: "Dim1"}, {0: "Dim2", 1: "Dim3"}],
[{0: "Dim4", 1: "Dim5"}, {0: "Dim6", 1: "Dim7"}],
]
},
as_string,
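The restructured expectations above encode the new flattened layout of a ``DynamicCache``: it now maps to two nested lists, ``[key shapes per layer, value shapes per layer]``, instead of one flat list of all tensors. A minimal sketch under that assumption, for a one-layer cache:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    # One layer: a (key, value) pair of [batch, heads, seq, dim] tensors.
    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
    # Matching dynamic shapes: first sub-list for keys, second for values.
    ds = [[{0: "batch", 2: "seq"}], [{0: "batch", 2: "seq"}]]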
71 changes: 71 additions & 0 deletions _unittests/ut_helpers/test_cache_helper.py
@@ -0,0 +1,71 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export import CoupleInputsDynamicShapes
from onnx_diagnostic.torch_export_patches.patch_inputs import (
convert_dynamic_axes_into_dynamic_shapes,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class TestCacheHelpers(ExtTestCase):
def test_string_type(self):
DYN = torch.export.Dim.DYNAMIC
self.assertEqual("DYNAMIC", string_type(DYN, verbose=0))
AUTO = torch.export.Dim.AUTO
self.assertEqual("AUTO", string_type(AUTO, verbose=0))
self.assertEqual("#1[DYNAMIC]", string_type([DYN]))

batch = torch.export.Dim("batch")
dynamic_shapes = dict(
input_ids={0: batch, 1: "seq"},
attention_mask={0: batch, 1: "seq"},
position_ids={0: batch, 1: "seq"},
past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
)
self.assertEqual(
"dict(input_ids:{0:Dim(batch),1:DYN(seq)},"
"attention_mask:{0:Dim(batch),1:DYN(seq)},"
"position_ids:{0:Dim(batch),1:DYN(seq)},"
"past_key_values:#2[#1[{0:Dim(batch),2:DYN(seq)}],"
"#1[{0:Dim(batch),2:DYN(seq)}]])",
string_type(dynamic_shapes),
)

def test_replace_by(self):
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_dynamic_cache(
[(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
)
kwargs = dict(
input_ids=torch.zeros(2, 3),
attention_mask=torch.zeros(2, 3),
position_ids=torch.zeros(2, 3),
past_key_values=past_key_values,
)
batch = torch.export.Dim("batch")
dynamic_shapes = dict(
input_ids={0: batch, 1: "seq"},
attention_mask={0: batch, 1: "seq"},
position_ids={0: batch, 1: "seq"},
past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
)

DYN = torch.export.Dim.DYNAMIC
nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes
)
self.assertEqual(dynamic_shapes, nds)

with bypass_export_some_errors(patch_transformers=True):
cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes)
res = cpl.replace_string_by()
dsc = res["past_key_values"]
self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc)


if __name__ == "__main__":
unittest.main(verbosity=2)
21 changes: 21 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_inputs.py
@@ -5,6 +5,7 @@
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.patch_inputs import (
convert_dynamic_axes_into_dynamic_shapes,
use_dyn_not_str,
)


@@ -111,6 +112,26 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
string_type(res[1], with_shape=True),
)

def test_use_dyn_not_str(self):
batch = torch.export.Dim("batch")
dynamic_shapes = dict(
input_ids={0: batch, 1: "seq"},
attention_mask={0: batch, 1: "seq"},
position_ids={0: batch, 1: "seq"},
past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]],
)
res = use_dyn_not_str(dynamic_shapes)
DYN = torch.export.Dim.DYNAMIC
self.assertEqual(
dict(
input_ids={0: batch, 1: DYN},
attention_mask={0: batch, 1: DYN},
position_ids={0: batch, 1: DYN},
past_key_values=[[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]],
),
res,
)


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 7 additions & 1 deletion _unittests/ut_torch_models/test_test_helpers.py
@@ -2,7 +2,12 @@
import unittest
import packaging.version as pv
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
hide_stdout,
ignore_warnings,
requires_torch,
)
from onnx_diagnostic.torch_models.test_helper import (
get_inputs_for_task,
validate_model,
@@ -54,6 +59,7 @@ def test_validate_model_export(self):
self.assertIsInstance(summary, dict)
self.assertIsInstance(data, dict)

@requires_torch("2.7")
@hide_stdout()
@ignore_warnings(FutureWarning)
def test_validate_model_onnx(self):
4 changes: 2 additions & 2 deletions onnx_diagnostic/_command_lines_parser.py
@@ -302,7 +302,7 @@ def get_parser_validate() -> ArgumentParser:

def _cmd_validate(argv: List[Any]):
from .helpers import string_type
from .torch_models.test_helper import get_inputs_for_task, validate_model, _ds_clean
from .torch_models.test_helper import get_inputs_for_task, validate_model
from .tasks import supported_tasks

parser = get_parser_validate()
@@ -320,7 +320,7 @@ def _cmd_validate(argv: List[Any]):
print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
print("-- dynamic_shapes")
for k, v in data["dynamic_shapes"].items():
print(f" + {k.ljust(max_length)}: {_ds_clean(v)}")
print(f" + {k.ljust(max_length)}: {string_type(v)}")
else:
# Let's skip any invalid combination if known to be unsupported
if (
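The private ``_ds_clean`` helper is dropped in favor of the generic ``string_type`` formatter. Judging from the expected strings in the new cache tests, it renders dynamic shapes compactly (output shown as in ``test_string_type``):

    import torch
    from onnx_diagnostic.helpers import string_type

    batch = torch.export.Dim("batch")
    print(string_type({0: batch, 1: "seq"}))
    # {0:Dim(batch),1:DYN(seq)} -- named Dims print as Dim(...), strings as DYN(...)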
64 changes: 51 additions & 13 deletions onnx_diagnostic/export/dynamic_shapes.py
@@ -92,7 +92,8 @@ def replace_string_by(self, value: Any = None):
return self._generic_walker(
lambda inputs, ds, value=value: self._replace_string_dim_tensor(
inputs, ds, value=value
)
),
flatten_unflatten=True,
)

@classmethod
@@ -135,7 +136,8 @@ def replace_by_string(self):
return self._generic_walker(
lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
inputs, ds, unique=unique
)
),
flatten_unflatten=True,
)

@classmethod
@@ -203,7 +205,7 @@ def invalid_dimensions_for_export(self):
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
"""
return self._generic_walker(self._valid_shapes_tensor)
return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)

@classmethod
def _valid_shapes_tensor(cls, inputs, ds):
@@ -221,7 +223,9 @@ def _valid_shapes_tensor(cls, inputs, ds):
issues[i] = f"d=[{d}]"
return issues if issues else None

def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
def _generic_walker(
self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
):
"""
Generic walker that traverses inputs and dynamic_shapes together.
The function returns a result with the same structure as the dynamic shapes.
@@ -231,15 +235,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
f"Type mismatch, args={string_type(self.args)} and "
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
)
res = self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
res = self._generic_walker_step(
processor,
self.kwargs,
self.dynamic_shapes,
flatten_unflatten=flatten_unflatten,
)
return (tuple(), res) if args_kwargs else res

if not self.kwargs:
assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
f"Type mismatch, args={string_type(self.args)} and "
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
)
res = self._generic_walker_step(processor, self.args, self.dynamic_shapes)
res = self._generic_walker_step(
processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
)
return (res, {}) if args_kwargs else res

assert isinstance(self.dynamic_shapes, dict), (
@@ -250,12 +261,22 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
self.dynamic_shapes
):
# No dynamic shapes for the positional arguments.
return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
return self._generic_walker_step(
processor,
self.kwargs,
self.dynamic_shapes,
flatten_unflatten=flatten_unflatten,
)

if isinstance(self.args_names, list):
if not set(self.args_names) & set(self.dynamic_shapes):
# No dynamic shapes for the positional arguments.
return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
return self._generic_walker_step(
processor,
self.kwargs,
self.dynamic_shapes,
flatten_unflatten=flatten_unflatten,
)

assert self.args_names, (
"args and kwargs are filled, then args_names must be specified in "
@@ -268,7 +289,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
)
kwargs = dict(zip(self.args_names, self.args))
kwargs.update(self.kwargs)
res = self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
res = self._generic_walker_step(
processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
)
if args_kwargs:
pgs = [None for _ in range(len(self.args))]
kws = {}
@@ -286,7 +309,9 @@ def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
)

@classmethod
def _generic_walker_step(cls, processor: Callable, inputs, ds):
def _generic_walker_step(
cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
):
if isinstance(inputs, torch.Tensor):
return processor(inputs, ds)
if isinstance(inputs, (int, float, str)):
@@ -303,7 +328,11 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
if isinstance(inputs, (tuple, list)):
value = []
for i, d in zip(inputs, ds):
value.append(cls._generic_walker_step(processor, i, d))
value.append(
cls._generic_walker_step(
processor, i, d, flatten_unflatten=flatten_unflatten
)
)
return (
(value if isinstance(ds, list) else tuple(value))
if any(v is not None for v in value)
@@ -314,7 +343,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
dvalue = {}
for k, v in inputs.items():
t = cls._generic_walker_step(processor, v, ds[k])
t = cls._generic_walker_step(
processor, v, ds[k], flatten_unflatten=flatten_unflatten
)
if t is not None:
dvalue[k] = t
return dvalue if dvalue else None
@@ -325,11 +356,18 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
f"torch.utils._pytree.register_pytree_node, it is not possible to "
f"map this class with the given dynamic shapes."
)
if flatten_unflatten:
flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
return cls._generic_walker_step(
processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
)
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
if all(isinstance(t, torch.Tensor) for t in flat):
# We need to flatten dynamic shapes as well
ds = flatten_dynamic_shapes(ds)
return cls._generic_walker_step(processor, flat, ds)
return cls._generic_walker_step(
processor, flat, ds, flatten_unflatten=flatten_unflatten
)

class ChangeDimensionProcessor:
def __init__(self, desired_values):
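The new ``flatten_unflatten`` flag changes how custom classes such as ``DynamicCache`` are walked: instead of tree-flattening the inputs into a plain list of tensors, they are flattened and re-nested with ``flatten_unflatten_for_dynamic_shapes`` so the result lines up with the nested dynamic-shapes structure. The effect, as exercised by the new ``test_replace_by`` (a sketch; the reduced ``kwargs`` is an assumption, and the cache registration relies on the patches being active):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.export import CoupleInputsDynamicShapes
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
    kwargs = dict(past_key_values=cache)
    batch = torch.export.Dim("batch")
    ds = dict(past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]])

    with bypass_export_some_errors(patch_transformers=True):
        res = CoupleInputsDynamicShapes(tuple(), kwargs, ds).replace_string_by()
    # "seq" is replaced by torch.export.Dim.DYNAMIC while the nested
    # [[key shapes], [value shapes]] structure is preserved.
    print(res["past_key_values"])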