From 28f35155416b0ae0f42111fcddfbe20ddfaac5f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 12:00:47 +0100 Subject: [PATCH 01/20] Changes Cache serialization --- .../test_patch_serialization_transformers.py | 23 +- onnx_diagnostic/helpers/cache_helper.py | 22 +- onnx_diagnostic/helpers/helper.py | 5 +- .../serialization/transformers_impl.py | 244 ++++++++++-------- 4 files changed, 177 insertions(+), 117 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py index a54abbe9..352be02a 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py @@ -8,6 +8,7 @@ make_static_cache, make_sliding_window_cache, flatten_unflatten_for_dynamic_shapes, + make_dynamic_shapes_kv_cache, CacheKeyValue, ) from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( @@ -64,8 +65,8 @@ def forward(self, cache): model(cache) DYN = torch.export.Dim.DYNAMIC ds = [ - [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]], - [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]], + make_dynamic_shapes_kv_cache(cache1, {0: DYN}), + make_dynamic_shapes_kv_cache(cache2, {0: DYN}), ] with torch_export_patches(patch_transformers=True): @@ -99,9 +100,15 @@ def forward(self, cache): model = Model() model(cache) DYN = torch.export.Dim.DYNAMIC - ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]] + ds = make_dynamic_shapes_kv_cache(cache, {0: DYN}) + self.assertEqual(len(ds), 6) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): + flat, _spec = torch.utils._pytree.tree_flatten(cache) + self.assertEqual(len(flat), len(ds)) + unflat = torch.utils._pytree.tree_unflatten(flat, _spec) + if hasattr(unflat, "layers"): + self.assertEqual(len(unflat.layers), 3) torch.export.export(model, (cache,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) @@ -195,7 +202,7 @@ def forward(self, cache): model = Model() model(cache) DYN = torch.export.Dim.DYNAMIC - ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]] + ds = make_dynamic_shapes_kv_cache(cache, {0: DYN}) with torch_export_patches(patch_transformers=True): torch.export.export(model, (cache,), dynamic_shapes=(ds,)) @@ -265,9 +272,7 @@ def test_static_cache(self): flat, _spec = torch.utils._pytree.tree_flatten(bo) unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True) self.assertIsInstance(unflat, list) - self.assertEqual( - "#2[#3[T1r4,T1r4,T1r4],#3[T1r4,T1r4,T1r4]]", self.string_type(unflat) - ) + self.assertEqual("#6[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]", self.string_type(unflat)) # export class Model(torch.nn.Module): @@ -278,7 +283,7 @@ def forward(self, cache): model = Model() model(bo) DYN = torch.export.Dim.DYNAMIC - ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]] + ds = make_dynamic_shapes_kv_cache(bo, {0: DYN}) with torch_export_patches(patch_transformers=True, stop_if_static=1): torch.export.export(model, (bo,), dynamic_shapes=(ds,)) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 5d54a289..0ab9014a 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import 
packaging.version as pv import torch import transformers @@ -46,9 +46,14 @@ def __init__(self, cache=None): raise NotImplementedError(f"type(cache)={type(cache)}") def make_dynamic_cache(self): - """Do the reverse operation.""" + """Does the reverse operation.""" return make_dynamic_cache(list(zip(self.key_cache, self.value_cache))) + @property + def n_layers(self) -> int: + """Returns the number of layers.""" + return len(self.key_cache) if self.key_cache else 0 + def flatten_unflatten_for_dynamic_shapes( obj: Any, @@ -134,6 +139,19 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool: return len(cache2.key_cache) == len(cache.value_cache) +def make_dynamic_shapes_kv_cache( + cache: transformers.cache_utils.Cache, shape_of_one: Dict[str, Any] +) -> List[Dict[int, Any]]: + """ + Returns the dynamic shapes for key-value cache + + :param cache: a cache + :param shape_of_one: shape of one element + :return: dynamic shapes + """ + return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)] + + if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): def make_dynamic_cache( diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 8b372dcc..d140c890 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1,6 +1,7 @@ import ast import enum import inspect +import itertools from dataclasses import is_dataclass, fields from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np @@ -948,8 +949,8 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any: from .cache_helper import CacheKeyValue kc = CacheKeyValue(x) - res = flatten_object(kc.key_cache) + flatten_object(kc.value_cache) - return tuple(res) + return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache))) + if x.__class__.__name__ == "EncoderDecoderCache": res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache) return tuple(res) diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 0d19b120..548bc1dd 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -1,13 +1,20 @@ -from typing import Any, List, Set, Tuple +import itertools +from typing import Any, Callable, List, Set, Tuple import torch from transformers.cache_utils import ( + Cache, DynamicCache, EncoderDecoderCache, HybridCache, - SlidingWindowCache, StaticCache, ) +try: + from transformers.cache_utils import SlidingWindowCache +except ImportError: + SlidingWindowCache = None + + try: from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: @@ -30,66 +37,36 @@ } -############ -# MambaCache -############ - - -def flatten_mamba_cache( - mamba_cache: MambaCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" - assert isinstance(mamba_cache.conv_states, list) and isinstance( - mamba_cache.ssm_states, list - ), ( - f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, " - f"{type(mamba_cache.ssm_states)}" +def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]: + ca = CacheKeyValue(cache) + flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache))) + keys = list( + 
itertools.chain.from_iterable(
+            (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+        )
     )
-    flat = [
-        ("conv_states", mamba_cache.conv_states),
-        ("ssm_states", mamba_cache.ssm_states),
-    ]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return flat, keys


-def unflatten_mamba_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> MambaCache:
-    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
-    conv_states, ssm_states = values
-
-    class _config:
-        def __init__(self):
-            if isinstance(conv_states, list):
-                self.intermediate_size = conv_states[0].shape[1]
-                self.state_size = ssm_states[0].shape[2]
-                self.conv_kernel = conv_states[0].shape[2]
-                self.num_hidden_layers = len(conv_states)
-            else:
-                self.intermediate_size = conv_states.shape[2]
-                self.state_size = ssm_states.shape[3]
-                self.conv_kernel = conv_states.shape[3]
-                self.num_hidden_layers = conv_states.shape[0]
-
-    cache = MambaCache(
-        _config(),
-        max_batch_size=1,
-        dtype=values[-1][0].dtype,
-        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
-    )
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
+def _flatten_with_keys_cache(
+    cache: Cache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    values, context = _flatten_key_value_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context


-def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    values, context = flatten_mamba_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+def _unflatten_cache(
+    make_cache: Callable,
+    values: List[Any],
+    context: torch.utils._pytree.Context,
+    output_type=None,
+) -> DynamicCache:
+    """Restores a key-value cache (e.g. ``DynamicCache``) from python objects."""
+    res = make_cache(list(zip(values[::2], values[1::2])))
+    assert output_type is None or isinstance(
+        res, output_type
+    ), f"Type mismatch between {output_type} (expected) and {type(res)}"
+    return res


 ##############
@@ -101,24 +78,21 @@ def flatten_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    ca = CacheKeyValue(dynamic_cache)
-    flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return _flatten_key_value_cache(dynamic_cache)


 def flatten_with_keys_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    values, context = flatten_dynamic_cache(dynamic_cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+    return _flatten_with_keys_cache(dynamic_cache)


 def unflatten_dynamic_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    return make_dynamic_cache(list(zip(values[0], values[1])))
+    return _unflatten_cache(make_dynamic_cache, values, context, output_type=output_type)
############# @@ -130,24 +104,21 @@ def flatten_hybrid_cache( cache: HybridCache, ) -> Tuple[List[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" - ca = CacheKeyValue(cache) - flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] + return _flatten_key_value_cache(cache) def flatten_with_keys_hybrid_cache( cache: HybridCache, ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" - values, context = flatten_hybrid_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + return _flatten_with_keys_cache(cache) def unflatten_hybrid_cache( values: List[Any], context: torch.utils._pytree.Context, output_type=None ) -> HybridCache: """Restores a :class:`transformers.cache_utils.HybridCache` from python objects.""" - return make_hybrid_cache(list(zip(values[0], values[1]))) + return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type) ############# @@ -163,26 +134,27 @@ def flatten_static_cache( assert not ca.key_cache or cache.max_cache_len == ca.key_cache[0].shape[2], ( f"Serialization doet not work when " f"cache.max_cache_len={cache.max_cache_len} != " - f"cache.key_cache[0].shape[2]={ca.keu_cache[0].shape[2]}" + f"cache.key_cache[0].shape[2]={ca.key_cache[0].shape[2]}" ) - flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] + return _flatten_key_value_cache(cache) def flatten_with_keys_static_cache( cache: StaticCache, ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" - values, context = flatten_static_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + return _flatten_with_keys_cache(cache) def unflatten_static_cache( values: List[Any], context: torch.utils._pytree.Context, output_type=None ) -> StaticCache: """Restores a :class:`transformers.cache_utils.StaticCache` from python objects.""" - return make_static_cache( - list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2] + return _unflatten_cache( + lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]), + values, + context, + output_type=output_type, ) @@ -191,34 +163,36 @@ def unflatten_static_cache( #################### -def flatten_sliding_window_cache( - cache: SlidingWindowCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.SlidingWindowCache` - with python objects. - """ - ca = CacheKeyValue(cache) - flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)] - return [f[1] for f in flat], [f[0] for f in flat] - - -def flatten_with_keys_sliding_window_cache( - cache: SlidingWindowCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """ - Serializes a :class:`transformers.cache_utils.SlidingWindowCache` - with python objects. 
- """ - values, context = flatten_sliding_window_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - -def unflatten_sliding_window_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> SlidingWindowCache: - """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects.""" - return make_sliding_window_cache(list(zip(values[0], values[1]))) +if SlidingWindowCache: + + def flatten_sliding_window_cache( + cache: SlidingWindowCache, + ) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. + """ + return _flatten_key_value_cache(cache) + + def flatten_with_keys_sliding_window_cache( + cache: SlidingWindowCache, + ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. + """ + return _flatten_with_keys_cache(cache) + + def unflatten_sliding_window_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None + ) -> SlidingWindowCache: + """ + Restores a :class:`transformers.cache_utils.SlidingWindowCache` + from python objects. + """ + return _unflatten_cache( + make_sliding_window_cache, values, context, output_type=output_type + ) ##################### @@ -265,6 +239,68 @@ def unflatten_encoder_decoder_cache( ) +############ +# MambaCache +############ + + +def flatten_mamba_cache( + mamba_cache: MambaCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + assert isinstance(mamba_cache.conv_states, list) and isinstance( + mamba_cache.ssm_states, list + ), ( + f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, " + f"{type(mamba_cache.ssm_states)}" + ) + flat = [ + ("conv_states", mamba_cache.conv_states), + ("ssm_states", mamba_cache.ssm_states), + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def unflatten_mamba_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> MambaCache: + """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" + conv_states, ssm_states = values + + class _config: + def __init__(self): + if isinstance(conv_states, list): + self.intermediate_size = conv_states[0].shape[1] + self.state_size = ssm_states[0].shape[2] + self.conv_kernel = conv_states[0].shape[2] + self.num_hidden_layers = len(conv_states) + else: + self.intermediate_size = conv_states.shape[2] + self.state_size = ssm_states.shape[3] + self.conv_kernel = conv_states.shape[3] + self.num_hidden_layers = conv_states.shape[0] + + cache = MambaCache( + _config(), + max_batch_size=1, + dtype=values[-1][0].dtype, + device="cpu" if values[-1][0].get_device() < 0 else "cuda", + ) + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + values, context = flatten_mamba_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + ############# # dataclasses ############# From 3251a8ca3879caa07e75ded8e62e62240e922a79 Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 12:02:35 +0100 Subject: [PATCH 02/20] mypy --- onnx_diagnostic/helpers/cache_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 0ab9014a..f3fdc5ef 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -140,7 +140,7 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool: def make_dynamic_shapes_kv_cache( - cache: transformers.cache_utils.Cache, shape_of_one: Dict[str, Any] + cache: transformers.cache_utils.Cache, shape_of_one: Dict[int, Any] ) -> List[Dict[int, Any]]: """ Returns the dynamic shapes for key-value cache From 519587549ce1d39bdaaa8b2c15ca79889e2c4429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 13:42:21 +0100 Subject: [PATCH 03/20] fix --- _unittests/ut_export/test_shape_helper.py | 6 ++++-- .../tasks/automatic_speech_recognition.py | 10 ++-------- onnx_diagnostic/tasks/feature_extraction.py | 10 ++-------- onnx_diagnostic/tasks/image_text_to_text.py | 5 +---- onnx_diagnostic/tasks/summarization.py | 10 ++-------- onnx_diagnostic/tasks/text2text_generation.py | 10 ++-------- onnx_diagnostic/tasks/text_generation.py | 15 +++------------ 7 files changed, 16 insertions(+), 50 deletions(-) diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index 95b081bd..86e18845 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -225,8 +225,10 @@ def test_make_fake_with_dynamic_dimensions_whole(self): "attention_mask": {0: "batch", 1: "cache+seq"}, "position_ids": {0: "batch", 1: "seq_length"}, "past_key_values": [ - [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}], - [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}], + {0: "batch", 2: "cache_length"}, + {0: "batch", 2: "cache_length"}, + {0: "batch", 2: "cache_length"}, + {0: "batch", 2: "cache_length"}, ], }, ) diff --git a/onnx_diagnostic/tasks/automatic_speech_recognition.py b/onnx_diagnostic/tasks/automatic_speech_recognition.py index c122c086..22724f25 100644 --- a/onnx_diagnostic/tasks/automatic_speech_recognition.py +++ b/onnx_diagnostic/tasks/automatic_speech_recognition.py @@ -84,14 +84,8 @@ def get_inputs( "cache_position": {0: seq_length}, "encoder_outputs": [{0: batch}], # last_hidden_state "past_key_values": [ - [ - [{0: batch} for _ in range(num_hidden_layers)], - [{0: batch} for _ in range(num_hidden_layers)], - ], - [ - [{0: batch} for _ in range(num_hidden_layers)], - [{0: batch} for _ in range(num_hidden_layers)], - ], + [{0: batch} for _ in range(num_hidden_layers * 2)], + [{0: batch} for _ in range(num_hidden_layers * 2)], ], } inputs = dict( diff --git a/onnx_diagnostic/tasks/feature_extraction.py b/onnx_diagnostic/tasks/feature_extraction.py index b049a5b6..58b1e3c5 100644 --- a/onnx_diagnostic/tasks/feature_extraction.py +++ b/onnx_diagnostic/tasks/feature_extraction.py @@ -109,14 +109,8 @@ def get_inputs( cache_length = "cache_length_key" cache_length2 = "cache_length_val" shapes["past_key_values"] = [ # type: ignore[assignment] - [ - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - ], - [ - [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)], - ], + 
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)], + [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)], ] res = dict(inputs=inputs, dynamic_shapes=shapes) diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 0bb8a4e9..3ad79109 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -151,10 +151,7 @@ def _get_inputs_gemma3( }, "position_ids": {0: batch, 1: seq_length}, "cache_position": {0: seq_length}, - "past_key_values": [ - [{0: batch} for _ in range(num_hidden_layers)], - [{0: batch} for _ in range(num_hidden_layers)], - ], + "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)], "pixel_values": {0: batch}, "use_cache": None, } diff --git a/onnx_diagnostic/tasks/summarization.py b/onnx_diagnostic/tasks/summarization.py index 5760c41c..fe9c8138 100644 --- a/onnx_diagnostic/tasks/summarization.py +++ b/onnx_diagnostic/tasks/summarization.py @@ -81,14 +81,8 @@ def get_inputs( "attention_mask": {0: batch, 1: "seq_mask"}, # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC}, "past_key_values": [ - [ - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - ], - [ - [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)], - ], + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)], + [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)], ], # one these is selected based on the forward method signature # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC}, diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py index fc8cd2e0..a051e4fd 100644 --- a/onnx_diagnostic/tasks/text2text_generation.py +++ b/onnx_diagnostic/tasks/text2text_generation.py @@ -83,14 +83,8 @@ def get_inputs( "attention_mask": {0: batch, 1: "seq_mask"}, # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC}, "past_key_values": [ - [ - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - ], - [ - [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)], - ], + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)], + [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)], ], # one these is selected based on the forward method signature # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC}, diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index eebd4aa9..6b38da5c 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -119,10 +119,7 @@ def get_inputs( 0: batch, 1: "cache+seq", # cache_length + seq_length }, - "cache_params": [ - [{0: batch} for _ in range(num_hidden_layers)], - [{0: batch} for _ in range(num_hidden_layers)], - ], + "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)], } inputs = dict( input_ids=torch.randint( @@ -176,12 +173,7 @@ def get_inputs( "input_ids": {0: batch, 1: seq_length}, "attention_mask": {0: batch, 2: "seq"}, "cache_position": {0: "seq"}, - "past_key_values": [ - # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - # [{0: batch, 2: cache_length} for _ 
in range(num_hidden_layers)], - [{0: batch} for _ in range(num_hidden_layers)], - [{0: batch} for _ in range(num_hidden_layers)], - ], + "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)], } inputs = dict( input_ids=torch.randint( @@ -222,8 +214,7 @@ def get_inputs( }, "position_ids": {0: batch, 1: seq_length}, "past_key_values": [ - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + {0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2) ], } From 627a5963ad738e09bc69d6c372f1470eca6cf010 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 14:55:44 +0100 Subject: [PATCH 04/20] other fixes --- _unittests/ut_export/test_shape_helper.py | 8 ++-- _unittests/ut_helpers/test_rt_helper.py | 45 ++++++++++++++++++- _unittests/ut_tasks/test_tasks.py | 4 +- .../test_patch_inputs.py | 8 ++-- onnx_diagnostic/export/shape_helper.py | 6 ++- onnx_diagnostic/helpers/rt_helper.py | 8 ++-- .../torch_models/untrained/llm_phi2.py | 5 +-- 7 files changed, 62 insertions(+), 22 deletions(-) diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index 86e18845..c533716e 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -155,8 +155,8 @@ def test_all_dynamic_shapes_from_inputs_dynamic_cache(self): "attention_mask": {0: "d_1_0", 1: "d_1_1"}, "position_ids": {0: "d_2_0", 1: "d_2_1"}, "past_key_values": [ - [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}], - [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}], + {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}, + {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}, ], }, ds, @@ -176,8 +176,8 @@ def test_guess_dynamic_shapes_from_inputs(self): "attention_mask": {0: "dd_0I0", 1: "dd_0I1"}, "input_ids": {0: "dd_1I0", 1: "dd_1I1"}, "past_key_values": [ - [{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}], - [{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}], + {0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}, + {0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}, ], "position_ids": {0: "dd_3I0", 1: "dd_3I1"}, }, diff --git a/_unittests/ut_helpers/test_rt_helper.py b/_unittests/ut_helpers/test_rt_helper.py index d24f8900..aa0f0155 100644 --- a/_unittests/ut_helpers/test_rt_helper.py +++ b/_unittests/ut_helpers/test_rt_helper.py @@ -5,13 +5,42 @@ from onnx_diagnostic.helpers.rt_helper import onnx_generate from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.export.api import to_onnx class TestRtSession(ExtTestCase): + def simple_generate_with_cache( + self, model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100 + ): + # First call: prefill + outputs = model( + input_ids, + attention_mask=torch.ones( + input_ids.shape, dtype=input_ids.dtype, device=input_ids.device + ), + use_cache=True, + ) + + # Next calls: decode + for _ in range(max_new_tokens): + next_token_logits = outputs.logits[:, -1, :] + past_key_values = outputs.past_key_values + next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True) + if next_token_id.item() == eos_token_id: + break + input_ids = torch.cat([input_ids, next_token_id], dim=-1) + outputs = model( + next_token_id, + use_cache=True, + past_key_values=past_key_values, + attention_mask=torch.ones( + input_ids.shape, dtype=input_ids.dtype, device=input_ids.device + ), + ) + return input_ids + 
@hide_stdout() def test_onnx_generate(self): - from experimental_experiment.torch_interpreter import to_onnx - mid = "arnir0/Tiny-LLM" print("-- test_onnx_generate: get model") data = get_untrained_model_with_inputs(mid) @@ -19,6 +48,7 @@ def test_onnx_generate(self): del inputs["position_ids"] del ds["position_ids"] input_ids = inputs["input_ids"] + print("----", input_ids.shape) folder = self.get_dump_folder("test_onnx_generate") model_name = os.path.join(folder, "model.onnx") print("-- test_onnx_generate: export model") @@ -29,13 +59,24 @@ def test_onnx_generate(self): kwargs=inputs, dynamic_shapes=ds, filename=model_name, + exporter="custom", ) print("-- test_onnx_generate: generate") res = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10) + n_inputs = input_ids.shape[1] + self.assertEqualArray(input_ids[:1], res[:, :n_inputs]) self.assertEqual(res.dtype, torch.int64) self.assertEqual(res.shape, (1, 13)) print("-- test_onnx_generate: done") + # expected = model.generate(input_ids[:1], max_new_tokens=10) + expected = self.simple_generate_with_cache(model, input_ids[:1], 2, max_new_tokens=10) + self.assertEqualArray(input_ids[:1], expected[:, :n_inputs]) + print("******", res) + print("******", expected) + self.assertEqual(expected.dtype, torch.int64) + self.assertEqual(expected.shape, (1, 13)) + self.assertEqualArray(expected, res) if __name__ == "__main__": diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index 0079b7e9..903b4d68 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -113,8 +113,8 @@ def test_automatic_speech_recognition_float32(self): "cache_position": {0: "seq_length"}, "encoder_outputs": [{0: "batch"}], "past_key_values": [ - [[{0: "batch"}, {0: "batch"}], [{0: "batch"}, {0: "batch"}]], - [[{0: "batch"}, {0: "batch"}], [{0: "batch"}, {0: "batch"}]], + [{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}], + [{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}], ], }, ds, diff --git a/_unittests/ut_torch_export_patches/test_patch_inputs.py b/_unittests/ut_torch_export_patches/test_patch_inputs.py index 400902a9..d454651d 100644 --- a/_unittests/ut_torch_export_patches/test_patch_inputs.py +++ b/_unittests/ut_torch_export_patches/test_patch_inputs.py @@ -48,8 +48,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_1(self): "attention_mask": {0: "batch_size", 1: "total_sequence_length"}, "input_ids": {0: "batch_size", 1: "sequence_length"}, "past_key_values": [ - [{0: "batch_size", 2: "past_sequence_length"}], - [{0: "batch_size", 2: "past_sequence_length"}], + {0: "batch_size", 2: "past_sequence_length"}, + {0: "batch_size", 2: "past_sequence_length"}, ], "position_ids": {0: "batch_size", 1: "sequence_length"}, }, @@ -98,8 +98,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self): "attention_mask": {0: "batch_size", 1: "sequence_length"}, "input_ids": {0: "batch_size", 1: "sequence_length"}, "past_key_values": [ - [{0: "batch_size", 2: "past_sequence_length"}], - [{0: "batch_size", 2: "past_sequence_length"}], + {0: "batch_size", 2: "past_sequence_length"}, + {0: "batch_size", 2: "past_sequence_length"}, ], "position_ids": {0: "batch_size", 1: "sequence_length"}, }, diff --git a/onnx_diagnostic/export/shape_helper.py b/onnx_diagnostic/export/shape_helper.py index e27131f1..8a07ba3b 100644 --- a/onnx_diagnostic/export/shape_helper.py +++ b/onnx_diagnostic/export/shape_helper.py @@ -260,8 +260,10 @@ def make_fake_with_dynamic_dimensions( "attention_mask": {0: "batch", 1: 
"cache+seq"}, "position_ids": {0: "batch", 1: "seq_length"}, "past_key_values": [ - [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}], - [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}], + {0: "batch", 2: "cache_length"}, + {0: "batch", 2: "cache_length"}, + {0: "batch", 2: "cache_length"}, + {0: "batch", 2: "cache_length"}, ], }, ) diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index 69440a1d..763c5401 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -206,7 +206,7 @@ def onnx_generate( ), f"Only text generation is supported but input_names == {input_names}" # First call: prefill - input_feeds = dict( + feeds = dict( input_ids=input_ids, attention_mask=torch.ones( input_ids.shape, dtype=input_ids.dtype, device=input_ids.device @@ -216,9 +216,9 @@ def onnx_generate( new_shape = tuple( _get_dim(i, s, batch=input_ids.shape[0]) for i, s in enumerate(shape) ) - input_feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype)) + feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype)) - outputs = session.run(None, input_feeds) + outputs = session.run(None, feeds) # Next calls: decode for _ in range(max_new_tokens): @@ -241,7 +241,7 @@ def onnx_generate( ), ) feeds.update(dict(zip(input_names[2:], outputs[1:]))) - outputs = session.run(None, input_feeds) + outputs = session.run(None, feeds) if return_session: return input_ids, session diff --git a/onnx_diagnostic/torch_models/untrained/llm_phi2.py b/onnx_diagnostic/torch_models/untrained/llm_phi2.py index 0c7f73f0..84cd22c2 100644 --- a/onnx_diagnostic/torch_models/untrained/llm_phi2.py +++ b/onnx_diagnostic/torch_models/untrained/llm_phi2.py @@ -84,10 +84,7 @@ def get_phi2( 0: batch, 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length }, - "past_key_values": [ - [{0: batch, 2: cache_length} for _ in range(n_layers)], - [{0: batch, 2: cache_length} for _ in range(n_layers)], - ], + "past_key_values": [{0: batch, 2: cache_length} for _ in range(n_layers * 2)], } inputs = dict( input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to( From e52ec0ba698643d7d75747ec0bef4941640ca933 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 16:30:09 +0100 Subject: [PATCH 05/20] fix other tests --- _unittests/ut_export/test_dynamic_shapes.py | 51 +++++++------ _unittests/ut_export/test_serialization.py | 21 ++---- _unittests/ut_export/test_shape_helper.py | 72 ++++++++----------- _unittests/ut_helpers/test_cache_helper.py | 25 ++++--- _unittests/ut_helpers/test_helper.py | 2 +- _unittests/ut_tasks/test_tasks.py | 4 +- .../test_dynamic_class.py | 5 +- onnx_diagnostic/export/shape_helper.py | 8 +-- 8 files changed, 79 insertions(+), 109 deletions(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 52b758f1..089b3f23 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -452,19 +452,18 @@ def forward(self, cache, z): ( ( [ - [{}, {}], - [ - { - 0: torch.export.Dim.DYNAMIC, - 2: torch.export.Dim.DYNAMIC, - 3: torch.export.Dim.DYNAMIC, - }, - { - 0: torch.export.Dim.DYNAMIC, - 2: torch.export.Dim.DYNAMIC, - 3: torch.export.Dim.DYNAMIC, - }, - ], + {}, + { + 0: torch.export.Dim.DYNAMIC, + 2: torch.export.Dim.DYNAMIC, + 3: torch.export.Dim.DYNAMIC, + }, + {}, + { + 0: torch.export.Dim.DYNAMIC, + 2: torch.export.Dim.DYNAMIC, + 3: 
torch.export.Dim.DYNAMIC, + }, ], {3: torch.export.Dim.DYNAMIC}, ), @@ -520,11 +519,10 @@ def forward(self, cache, z): ( ( [ - [{}, {}], - [ - {0: "dim_0I_1o_0l0", 2: "dim_0I_1o_0l2", 3: "dim_0I_1o_0l3"}, - {0: "dim_0I_1o_1l0", 2: "dim_0I_1o_1l2", 3: "dim_0I_1o_1l3"}, - ], + {}, + {0: "dim_0I_1o0", 2: "dim_0I_1o2", 3: "dim_0I_1o3"}, + {}, + {0: "dim_0I_3o0", 2: "dim_0I_3o2", 3: "dim_0I_3o3"}, ], {3: "dim_1I3"}, ), @@ -641,18 +639,18 @@ def test_couple_input_ds_cache(self): kwargs, { "A": ds_batch, - "B": (ds_batch, [[ds_batch, ds_batch], [ds_batch, ds_batch]]), + "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch]), }, ).invalid_dimensions_for_export(), ) self.assertEqual( - {"B": (None, [[None, {2: "d=[1]"}], [None, {2: "d=[1]"}]])}, + {"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])}, Cls( (), kwargs, { "A": ds_batch, - "B": (ds_batch, [[ds_batch, ds_batch_seq], [ds_batch, ds_batch_seq]]), + "B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]), }, ).invalid_dimensions_for_export(), ) @@ -831,18 +829,17 @@ def test_dynamic_cache_replace_by_string(self): DYN = torch.export.Dim.DYNAMIC ds = { - "cache": [ - [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], - [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], - ] + "cache": [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}] } inst = CoupleInputsDynamicShapes((), dict(cache=cache), ds) as_string = inst.replace_by_string() self.assertEqual( { "cache": [ - [{0: "Dim0", 1: "Dim1"}, {0: "Dim2", 1: "Dim3"}], - [{0: "Dim4", 1: "Dim5"}, {0: "Dim6", 1: "Dim7"}], + {0: "Dim0", 1: "Dim1"}, + {0: "Dim2", 1: "Dim3"}, + {0: "Dim4", 1: "Dim5"}, + {0: "Dim6", 1: "Dim7"}, ] }, as_string, diff --git a/_unittests/ut_export/test_serialization.py b/_unittests/ut_export/test_serialization.py index 07599bb3..070a76f1 100644 --- a/_unittests/ut_export/test_serialization.py +++ b/_unittests/ut_export/test_serialization.py @@ -30,7 +30,7 @@ def forward(self, cache): cache = self._get_cache() DYN = torch.export.Dim.DYNAMIC ds = {0: DYN, 1: DYN, 3: DYN} - dynamic_shapes = ([[ds, ds], [ds, ds]],) + dynamic_shapes = ([ds, ds, ds, ds],) with torch_export_patches(patch_transformers=True): exp = torch.export.export(Model(), (cache,), dynamic_shapes=dynamic_shapes) self.assertNotEmpty(exp) @@ -44,7 +44,7 @@ def forward(self, cache): cache = self._get_cache() flat_unflat = flatten_unflatten_for_dynamic_shapes(cache) s = string_type(flat_unflat, with_shape=True) - self.assertEqual("#2[#2[T1s2x4x1x7,T1s2x4x1x7],#2[T1s2x4x1x7,T1s2x4x1x7]]", s) + self.assertEqual("#4[T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7]", s) def test_dynamic_cache_bypass(self): class Model(torch.nn.Module): @@ -55,7 +55,7 @@ def forward(self, cache): with torch_export_patches(patch_transformers=True): flat_unflat = flatten_unflatten_for_dynamic_shapes(cache) s = string_type(flat_unflat, with_shape=True) - self.assertEqual("#2[#2[T1s2x4x1x7,T1s2x4x1x7],#2[T1s2x4x1x7,T1s2x4x1x7]]", s) + self.assertEqual("#4[T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7,T1s2x4x1x7]", s) def test_dynamic_cache_guess_static(self): class Model(torch.nn.Module): @@ -65,7 +65,7 @@ def forward(self, cache): cache = self._get_cache() md = ModelInputs(Model(), [(cache,)]) guessed = md.guess_dynamic_shapes() - self.assertEqual(guessed, (([[{}, {}], [{}, {}]],), {})) + self.assertEqual(guessed, (([{}, {}, {}, {}],), {})) def test_dynamic_cache_guess_auto(self): class Model(torch.nn.Module): @@ -77,7 +77,7 @@ def forward(self, cache): guessed = md.guess_dynamic_shapes(auto=True) AUTO = torch.export.Dim.AUTO ds = {i: AUTO for i 
in range(4)} # noqa: C420 - self.assertEqual(guessed, (([[ds, ds], [ds, ds]],), {})) + self.assertEqual(guessed, (([ds, ds, ds, ds],), {})) def test_dynamic_cache_guess_dynamic(self): class Model(torch.nn.Module): @@ -88,18 +88,11 @@ def forward(self, cache): Model(), [(self._get_cache(),), (self._get_cache(bsize=3, nheads=5),)] ) guessed = md.guess_dynamic_shapes() + print("****", guessed) DYN = torch.export.Dim.DYNAMIC self.assertEqual( + (([{0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN}],), {}), guessed, - ( - ( - [ - [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], - [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], - ], - ), - {}, - ), ) diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index c533716e..bd43828d 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -23,14 +23,14 @@ class TestShapeHelper(ExtTestCase): def test_all_dynamic_shape_from_cache(self): cache = make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))]) ds = all_dynamic_shapes_from_inputs(cache) - self.assertEqual([[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], ds) + self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds) @requires_torch("2.7.99") def test_all_dynamic_shape_all_transformers_cache(self): caches = [ ( make_dynamic_cache([(torch.ones((2, 2)), (torch.ones((2, 2)) * 2))]), - [[{0: "d_0_0", 1: "d_0_1"}], [{0: "d_1_0", 1: "d_1_1"}]], + [{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ), ( make_encoder_decoder_cache( @@ -51,28 +51,20 @@ def test_all_dynamic_shape_all_transformers_cache(self): ), [ [ - [ - {0: "d_0_0", 1: "d_0_1", 2: "d_0_2"}, - {0: "d_1_0", 1: "d_1_1", 2: "d_1_2"}, - {0: "d_2_0", 1: "d_2_1", 2: "d_2_2"}, - ], - [ - {0: "d_3_0", 1: "d_3_1", 2: "d_3_2"}, - {0: "d_4_0", 1: "d_4_1", 2: "d_4_2"}, - {0: "d_5_0", 1: "d_5_1", 2: "d_5_2"}, - ], + {0: "d_0_0", 1: "d_0_1", 2: "d_0_2"}, + {0: "d_1_0", 1: "d_1_1", 2: "d_1_2"}, + {0: "d_2_0", 1: "d_2_1", 2: "d_2_2"}, + {0: "d_3_0", 1: "d_3_1", 2: "d_3_2"}, + {0: "d_4_0", 1: "d_4_1", 2: "d_4_2"}, + {0: "d_5_0", 1: "d_5_1", 2: "d_5_2"}, ], [ - [ - {0: "d_6_0", 1: "d_6_1", 2: "d_6_2"}, - {0: "d_7_0", 1: "d_7_1", 2: "d_7_2"}, - {0: "d_8_0", 1: "d_8_1", 2: "d_8_2"}, - ], - [ - {0: "d_9_0", 1: "d_9_1", 2: "d_9_2"}, - {0: "d_10_0", 1: "d_10_1", 2: "d_10_2"}, - {0: "d_11_0", 1: "d_11_1", 2: "d_11_2"}, - ], + {0: "d_6_0", 1: "d_6_1", 2: "d_6_2"}, + {0: "d_7_0", 1: "d_7_1", 2: "d_7_2"}, + {0: "d_8_0", 1: "d_8_1", 2: "d_8_2"}, + {0: "d_9_0", 1: "d_9_1", 2: "d_9_2"}, + {0: "d_10_0", 1: "d_10_1", 2: "d_10_2"}, + {0: "d_11_0", 1: "d_11_1", 2: "d_11_2"}, ], ], ), @@ -85,16 +77,12 @@ def test_all_dynamic_shape_all_transformers_cache(self): ] ), [ - [ - {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"}, - {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"}, - {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"}, - ], - [ - {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}, - {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}, - {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"}, - ], + {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"}, + {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"}, + {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"}, + {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}, + {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}, + {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"}, ], ), ( @@ -107,16 +95,12 @@ def test_all_dynamic_shape_all_transformers_cache(self): max_cache_len=15, ), [ - [ - {0: "d_0_0", 1: "d_0_1", 2: 
"d_0_2", 3: "d_0_3"}, - {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"}, - {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"}, - ], - [ - {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}, - {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}, - {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"}, - ], + {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"}, + {0: "d_1_0", 1: "d_1_1", 2: "d_1_2", 3: "d_1_3"}, + {0: "d_2_0", 1: "d_2_1", 2: "d_2_2", 3: "d_2_3"}, + {0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}, + {0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}, + {0: "d_5_0", 1: "d_5_1", 2: "d_5_2", 3: "d_5_3"}, ], ), ] @@ -176,8 +160,8 @@ def test_guess_dynamic_shapes_from_inputs(self): "attention_mask": {0: "dd_0I0", 1: "dd_0I1"}, "input_ids": {0: "dd_1I0", 1: "dd_1I1"}, "past_key_values": [ - {0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}, - {0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}, + {0: "dd_2I_0o0", 2: "dd_2I_0o2"}, + {0: "dd_2I_1o0", 2: "dd_2I_1o2"}, ], "position_ids": {0: "dd_3I0", 1: "dd_3I1"}, }, diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index 23563875..5f9e0c92 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -73,7 +73,7 @@ def test_replace_by(self): input_ids={0: batch, 1: "seq"}, attention_mask={0: batch, 1: "seq"}, position_ids={0: batch, 1: "seq"}, - past_key_values=[[{0: batch, 2: "seq"}], [{0: batch, 2: "seq"}]], + past_key_values=[{0: batch, 2: "seq"}, {0: batch, 2: "seq"}], ) DYN = torch.export.Dim.DYNAMIC @@ -86,7 +86,7 @@ def test_replace_by(self): cpl = CoupleInputsDynamicShapes(tuple(), kwargs, dynamic_shapes) res = cpl.replace_string_by() dsc = res["past_key_values"] - self.assertEqual([[{0: batch, 2: DYN}], [{0: batch, 2: DYN}]], dsc) + self.assertEqual([{0: batch, 2: DYN}, {0: batch, 2: DYN}], dsc) def test_unflatten_flatten_dynamic_cache(self): with torch_export_patches(patch_transformers=True): @@ -94,7 +94,7 @@ def test_unflatten_flatten_dynamic_cache(self): self.assertIsInstance(c1, transformers.cache_utils.DynamicCache) unflat = flatten_unflatten_for_dynamic_shapes(c1) self.assertEqual( - "#2[#1[T1s4x4x4],#1[T1s4x4x4]]", self.string_type(unflat, with_shape=True) + "#2[T1s4x4x4,T1s4x4x4]", self.string_type(unflat, with_shape=True) ) self.assertEqual( "DynamicCache(key_cache=#1[T1s4x4x4], value_cache=#1[T1s4x4x4])", @@ -129,16 +129,15 @@ def test_unflatten_flatten_encoder_decoder_cache(self): self.assertIsInstance(unflat, list) self.assertEqual(len(unflat), 2) self.assertIsInstance(unflat[0], list) - self.assertEqual(len(unflat[0]), 2) - self.assertIsInstance(unflat[0][0], list) - self.assertEqual(len(unflat[0][0]), 3) + self.assertEqual(len(unflat[0]), 6) + self.assertIsInstance(unflat[0][0], torch.Tensor) self.assertEqual( - "#2[#3[T1s4x4x4,T1s4x4x4,T1s4x4x4],#3[T1s4x4x4,T1s4x4x4,T1s4x4x4]]", + "#6[T1s4x4x4,T1s4x4x4,T1s4x4x4,T1s4x4x4,T1s4x4x4,T1s4x4x4]", self.string_type(unflat[0], with_shape=True), ) self.assertEqual( - "#2[#2[#3[T1s4x4x4,T1s4x4x4,T1s4x4x4],#3[T1s4x4x4,T1s4x4x4,T1s4x4x4]]," - "#2[#3[T1s5x5x5,T1s5x5x5,T1s5x5x5],#3[T1s5x5x5,T1s5x5x5,T1s5x5x5]]]", + "#2[#6[T1s4x4x4,T1s4x4x4,T1s4x4x4,T1s4x4x4,T1s4x4x4,T1s4x4x4]," + "#6[T1s5x5x5,T1s5x5x5,T1s5x5x5,T1s5x5x5,T1s5x5x5,T1s5x5x5]]", self.string_type(unflat, with_shape=True), ) self.assertEqual( @@ -217,9 +216,9 @@ def test_unflatten_flatten_static_cache(self): self.assertEqual(len(flat), 6) unflat = flatten_unflatten_for_dynamic_shapes(c2) self.assertIsInstance(unflat, list) - self.assertEqual(len(unflat), 2) 
+ self.assertEqual(len(unflat), 6) self.assertEqual( - "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]", + "#6[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]", self.string_type(unflat, with_shape=True), ) @@ -256,9 +255,9 @@ def test_unflatten_flatten_hybrid_cache(self): self.assertEqual(len(flat), 6) unflat = flatten_unflatten_for_dynamic_shapes(c2) self.assertIsInstance(unflat, list) - self.assertEqual(len(unflat), 2) + self.assertEqual(len(unflat), 6) self.assertEqual( - "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]", + "#6[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]", self.string_type(unflat, with_shape=True), ) diff --git a/_unittests/ut_helpers/test_helper.py b/_unittests/ut_helpers/test_helper.py index eab075db..59b4c829 100644 --- a/_unittests/ut_helpers/test_helper.py +++ b/_unittests/ut_helpers/test_helper.py @@ -181,7 +181,7 @@ def test_flatten(self): def test_flatten_cache(self): cache = make_dynamic_cache([(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)]) flat = flatten_object(cache, drop_keys=True) - self.assertEqual(string_type(flat), "(T1r4,T1r4)") + self.assertEqual(string_type(flat), "#2[T1r4,T1r4]") cache = dict( cache=make_dynamic_cache( [(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)] diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index 903b4d68..5f9f9e8c 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -168,8 +168,8 @@ def test_automatic_speech_recognition_float16(self): "cache_position": {0: "seq_length"}, "encoder_outputs": [{0: "batch"}], "past_key_values": [ - [[{0: "batch"}, {0: "batch"}], [{0: "batch"}, {0: "batch"}]], - [[{0: "batch"}, {0: "batch"}], [{0: "batch"}, {0: "batch"}]], + [{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}], + [{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}], ], }, ds, diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index da4cbd91..4b1dddf4 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -76,10 +76,7 @@ def forward(self, x, cache): ep = torch.export.export( model, inputs, - dynamic_shapes=( - {0: DYN, 2: DYN}, - [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]], - ), + dynamic_shapes=({0: DYN, 2: DYN}, [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]), strict=False, ) mod = ep.module() diff --git a/onnx_diagnostic/export/shape_helper.py b/onnx_diagnostic/export/shape_helper.py index 8a07ba3b..91d148a9 100644 --- a/onnx_diagnostic/export/shape_helper.py +++ b/onnx_diagnostic/export/shape_helper.py @@ -303,8 +303,8 @@ def make_fake_with_dynamic_dimensions( f"Une more recent version of transformers (>=4.55), " f"'layers' not found in class {type(x)}" ) - assert ( - isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2 + assert isinstance(dynamic_shapes, list) and ( + not dynamic_shapes or not isinstance(dynamic_shapes[0], list) ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache" for il, layer in enumerate(x.layers): assert hasattr(layer, "keys") and hasattr(layer, "values"), ( @@ -312,10 +312,10 @@ def make_fake_with_dynamic_dimensions( f"not found in class {type(layer)} ({dir(layer)})" ) layer.keys = make_fake_with_dynamic_dimensions( - layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il] + layer.keys, fake_mode=fake_mode, 
dynamic_shapes=dynamic_shapes[il * 2] )[0] layer.values = make_fake_with_dynamic_dimensions( - layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il] + layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[il * 2 + 1] )[0] return x, fake_mode if x.__class__.__name__ == "EncoderDecoderCache": From 38fe3d6b211a8ba96822f964c6bc5ae819f4ce47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 17:32:33 +0100 Subject: [PATCH 06/20] fix modelbuilder --- _unittests/ut_export/test_dynamic_shapes.py | 54 +++++++++++++++++++ .../test_patch_inputs.py | 4 +- onnx_diagnostic/export/dynamic_shapes.py | 3 +- onnx_diagnostic/helpers/rt_helper.py | 45 ---------------- onnx_diagnostic/tasks/image_text_to_text.py | 13 +++-- .../torch_export_patches/patch_inputs.py | 2 +- onnx_diagnostic/torch_models/validate.py | 11 +--- 7 files changed, 70 insertions(+), 62 deletions(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 089b3f23..2b181961 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -862,6 +862,60 @@ def test_unbatch_inputs(self): s, ) + def test_guess_dynamic_cache_without_patches(self): + n_layers = 2 + bsize, nheads, slen, dim = 2, 4, 3, 7 + cache = make_dynamic_cache( + [ + (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim)) + for i in range(n_layers) + ] + ) + z = torch.randn((1, 1, 1, 7)) + cache2 = make_dynamic_cache( + [ + ( + torch.randn(bsize + 1, nheads, slen + 1, dim + 1), + torch.randn(bsize + 1, nheads, slen + 1, dim + 1), + ) + for i in range(n_layers) + ] + ) + inputs = [ + (cache, z), + (cache2, torch.randn((1, 1, 1, 8))), + ] + + class Model(torch.nn.Module): + def forward(self, cache, z): + cache = CacheKeyValue(cache) + return ( + z + + cache.key_cache[0] + + cache.key_cache[1] + + cache.value_cache[0] + + cache.value_cache[1] + ) + + mi = ModelInputs(Model(), inputs) + ds = mi.guess_dynamic_shapes() + DYN = torch.export.Dim.DYNAMIC + self.assertEqual( + ( + ( + [ + {0: DYN, 2: DYN, 3: DYN}, + {0: DYN, 2: DYN, 3: DYN}, + {0: DYN, 2: DYN, 3: DYN}, + {0: DYN, 2: DYN, 3: DYN}, + ], + {3: DYN}, + ), + {}, + ), + ds, + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_inputs.py b/_unittests/ut_torch_export_patches/test_patch_inputs.py index d454651d..b51e136a 100644 --- a/_unittests/ut_torch_export_patches/test_patch_inputs.py +++ b/_unittests/ut_torch_export_patches/test_patch_inputs.py @@ -88,8 +88,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self): ) self.assertEqual( [ - [{0: "batch_size", 2: "past_sequence_length"}], - [{0: "batch_size", 2: "past_sequence_length"}], + {0: "batch_size", 2: "past_sequence_length"}, + {0: "batch_size", 2: "past_sequence_length"}, ], res[2]["past_key_values"], ) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index b2a04421..7ecfb9f8 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -1,4 +1,5 @@ import inspect +import itertools from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -934,7 +935,7 @@ def guess_dynamic_shape_object( auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc", ) ) - return [key_cache, value_cache] + return list(itertools.chain.from_iterable(zip(key_cache, value_cache))) raise 
NotImplementedError( f"Unable to build dynamic shapes for type {set_types.pop()}: " diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index 763c5401..5aa896f4 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -96,54 +96,9 @@ def make_feeds( elif isinstance(i, float): i = np.array(i, dtype=np.float32) new_flat.append(i) - - # NOTE: model builder has a different order for past_key_values - # we need to reorder them to match the expected order - if is_modelbuilder: - # We assume that if "past_key_values" is in the names when it's - # modelbuilder - non_past_kv_input_names = [n for n in names if "past_key_values" not in n] - past_kv_names = [n for n in names if "past_key_values" in n] - reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names) - names = non_past_kv_input_names + reorder_past_kv_names return dict(zip(names, new_flat)) -def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]: - """ - Reorders the past_kvs for ModelBuilder to match the expected order - by PyTorch exported models. - - .. note:: - This function can take either the names or the actual tensors - as long as they are in a list. - - Conceptually, - - From:: - - [past_key_values.0.key, past_key_values.0.value, - past_key_values.1.key, past_key_values.1.value, ...] - - To:: - - [past_key_values.0.key, past_key_values.1.key, - ..., past_key_values.0.value, past_key_values.1.value, ...] - - :param past_kv: list of flattened inputs - :return: reordered list of flattened inputs - """ - total_len = len(past_kv) - if total_len % 2 != 0: - raise ValueError("The length of past_key_values should be even.") - keys = [] - values = [] - for i in range(0, total_len, 2): - keys.append(past_kv[i]) - values.append(past_kv[i + 1]) - return keys + values - - def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int: if isinstance(s, int): return s diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 3ad79109..653613f1 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -1,3 +1,4 @@ +import itertools from typing import Any, Callable, Dict, Optional, Tuple import torch from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache @@ -269,10 +270,14 @@ def get_inputs_default( "token_type_ids": {0: batch, 1: seq_length}, "attention_mask": {0: batch, 1: "cache+seq"}, "position_ids": {0: batch, 1: seq_length}, - "past_key_values": [ - [{0: batch} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - ], + "past_key_values": list( + itertools.chain.from_iterable( + zip( + [{0: batch} for _ in range(num_hidden_layers)], + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + ) + ) + ), "pixel_values": ( {0: batch, 1: images} if model.__class__.__name__ == "IdeficsForVisionText2Text" diff --git a/onnx_diagnostic/torch_export_patches/patch_inputs.py b/onnx_diagnostic/torch_export_patches/patch_inputs.py index 3c93a70d..72f55dd9 100644 --- a/onnx_diagnostic/torch_export_patches/patch_inputs.py +++ b/onnx_diagnostic/torch_export_patches/patch_inputs.py @@ -38,7 +38,7 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any: for v in subset.values(): axes = v break - new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]] + new_shape = [axes for i in range(cache_length * 2)] return new_shape if 
value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
         raise NotImplementedError(
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index 63340a8d..85dba7f9 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -12,7 +12,7 @@
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
 from ..helpers.helper import flatten_object
-from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
+from ..helpers.rt_helper import make_feeds
 from ..helpers.torch_helper import to_any, torch_deepcopy
 from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
 from ..tasks import random_input_kwargs
@@ -1478,7 +1478,7 @@ def _mk(key, flavour=flavour):
             data[k_input],
             use_numpy=True,
             check_flatten=False,
-            is_modelbuilder=data["exporter"] == "modelbuilder",
+            is_modelbuilder=data["exporter"] == "modelbuilder",  # to remove position_ids
         )
         if verbose:
             print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
@@ -1501,13 +1501,6 @@ def _mk(key, flavour=flavour):
             repeat=repeat,
             warmup=warmup,
         )
-        # NOTE: modelbuilder has different order on past_kv outputs
-        if data["exporter"] == "modelbuilder":
-            logits = got[:1]
-            past_key_values = got[1:]
-            reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
-            got = logits + reorder_past_key_values
-
         if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
             return summary, data

From 1043a16f6860de8dd9edc5c749876de737e31fe5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 28 Oct 2025 17:50:21 +0100
Subject: [PATCH 07/20] disable two examples

---
 _unittests/ut_helpers/test_rt_helper.py       | 43 +++++++++++++++----
 .../test_documentation_examples.py            | 21 +++++----
 2 files changed, 47 insertions(+), 17 deletions(-)

diff --git a/_unittests/ut_helpers/test_rt_helper.py b/_unittests/ut_helpers/test_rt_helper.py
index aa0f0155..107d1a9f 100644
--- a/_unittests/ut_helpers/test_rt_helper.py
+++ b/_unittests/ut_helpers/test_rt_helper.py
@@ -2,7 +2,10 @@
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.helpers import max_diff, flatten_object
 from onnx_diagnostic.helpers.rt_helper import onnx_generate
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.export.api import to_onnx


 class TestRtSession(ExtTestCase):
     def simple_generate_with_cache(
         self,
         model,
         input_ids: torch.Tensor,
         eos_token_id: int,
         session: InferenceSessionForTorch,
         max_new_tokens: int = 100,
     ):
         # First call: prefill
         outputs = model(
             input_ids,
             use_cache=True,
             attention_mask=torch.ones(
                 input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
             ),
         )

         # Next calls: decode
         for _ in range(max_new_tokens):
             next_token_logits = outputs.logits[:, -1, :]
             next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
             if next_token_id.item() == eos_token_id:
                 break
             input_ids = torch.cat([input_ids, next_token_id], dim=-1)
             attention_mask = torch.ones(
                 input_ids.shape,
dtype=input_ids.dtype, device=input_ids.device + ) + feeds = dict( + zip( + session.input_names, + torch_deepcopy( + flatten_object( + [next_token_id, attention_mask, outputs.past_key_values] + ) + ), + ) + ) + onnx_results = session.run(None, feeds) outputs = model( next_token_id, use_cache=True, - past_key_values=past_key_values, - attention_mask=torch.ones( - input_ids.shape, dtype=input_ids.dtype, device=input_ids.device - ), + past_key_values=outputs.past_key_values, + attention_mask=attention_mask, ) + diff = max_diff(outputs, onnx_results) + print("****", diff) return input_ids @hide_stdout() @@ -63,14 +84,18 @@ def test_onnx_generate(self): ) print("-- test_onnx_generate: generate") - res = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10) + res, session = onnx_generate( + model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True + ) n_inputs = input_ids.shape[1] self.assertEqualArray(input_ids[:1], res[:, :n_inputs]) self.assertEqual(res.dtype, torch.int64) self.assertEqual(res.shape, (1, 13)) print("-- test_onnx_generate: done") # expected = model.generate(input_ids[:1], max_new_tokens=10) - expected = self.simple_generate_with_cache(model, input_ids[:1], 2, max_new_tokens=10) + expected = self.simple_generate_with_cache( + model, input_ids[:1], 2, max_new_tokens=10, session=session + ) self.assertEqualArray(input_ids[:1], expected[:, :n_inputs]) print("******", res) print("******", expected) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index f9a08f1d..910fe21f 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -84,6 +84,8 @@ def add_test_methods(cls): if not reason and not has_dot and name in {"plot_dump_intermediate_results.py"}: reason = "dot not installed" + # transformers + if ( not reason and name in {"plot_export_tiny_llm.py"} @@ -98,13 +100,23 @@ def add_test_methods(cls): ): reason = "transformers<4.52" + if ( + not reason + and name in {"plot_export_with_dynamic_cache.py", "plot_export_tiny_phi2.py"} + and not has_transformers("4.55") + ): + reason = "transformers<4.55" + + # pytorch + if ( not reason and name in { + "plot_export_hub_codellama.py", "plot_export_locate_issue.py", "plot_export_with_auto.py", - "plot_export_hub_codellama.py", + "plot_export_tiny_llm.py", } and not has_torch("2.8") ): @@ -117,13 +129,6 @@ def add_test_methods(cls): ): reason = "unstable, let's wait for the next version" - if ( - not reason - and name in {"plot_export_tiny_phi2.py"} - and not has_transformers("4.55") - ): - reason = "unstable, let's wait for the next version" - if not reason and name in { "plot_export_tiny_llm_dim01.py", "plot_export_tiny_llm_dim01_onnx.py", From e89379d57ec1b218cddefcc5df894ca518e072e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 28 Oct 2025 18:49:35 +0100 Subject: [PATCH 08/20] fix some issues --- _unittests/ut_helpers/test_rt_helper.py | 97 +++++++++++++++++-------- onnx_diagnostic/helpers/helper.py | 46 ++++-------- onnx_diagnostic/helpers/rt_helper.py | 33 +++++++-- 3 files changed, 108 insertions(+), 68 deletions(-) diff --git a/_unittests/ut_helpers/test_rt_helper.py b/_unittests/ut_helpers/test_rt_helper.py index 107d1a9f..54648918 100644 --- a/_unittests/ut_helpers/test_rt_helper.py +++ b/_unittests/ut_helpers/test_rt_helper.py @@ -1,9 +1,14 @@ import os import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, 
hide_stdout +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + hide_stdout, + requires_transformers, + requires_torch, +) from onnx_diagnostic.helpers import max_diff, flatten_object -from onnx_diagnostic.helpers.rt_helper import onnx_generate +from onnx_diagnostic.helpers.rt_helper import onnx_generate, make_empty_cache from onnx_diagnostic.helpers.torch_helper import torch_deepcopy from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs @@ -21,16 +26,33 @@ def simple_generate_with_cache( max_new_tokens: int = 100, ): # First call: prefill - outputs = model( - input_ids, - use_cache=True, - attention_mask=torch.ones( - input_ids.shape, dtype=input_ids.dtype, device=input_ids.device + attention_mask = torch.ones( + input_ids.shape, dtype=input_ids.dtype, device=input_ids.device + ) + feeds = { + **dict(zip(session.input_names[:2], [input_ids, attention_mask])), + **make_empty_cache( + input_ids.shape[0], + session.input_names[2:], + session.input_shapes[2:], + session.input_types[2:], ), + } + onnx_results = session.run(None, feeds) + + outputs = model(input_ids, use_cache=True, attention_mask=attention_mask) + + diff = max_diff(outputs, onnx_results) + assert diff["abs"] <= 0.1, ( + f"Unexpected issue with {type(model)}\ndiff={diff}" + f"\ninput_ids.shape={input_ids.shape}" + f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}" + f"\n got=\n" + f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}" ) # Next calls: decode - for _ in range(max_new_tokens): + for iteration in range(max_new_tokens): next_token_logits = outputs.logits[:, -1, :] next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True) if next_token_id.item() == eos_token_id: @@ -42,11 +64,14 @@ def simple_generate_with_cache( feeds = dict( zip( session.input_names, - torch_deepcopy( - flatten_object( - [next_token_id, attention_mask, outputs.past_key_values] + [ + t.detach() + for t in torch_deepcopy( + flatten_object( + [next_token_id, attention_mask, outputs.past_key_values] + ) ) - ), + ], ) ) onnx_results = session.run(None, feeds) @@ -57,9 +82,17 @@ def simple_generate_with_cache( attention_mask=attention_mask, ) diff = max_diff(outputs, onnx_results) - print("****", diff) + assert diff["abs"] <= 0.1, ( + f"Unexpected issue with {type(model)}, iteration={iteration}" + f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}" + f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}" + f"\n got=\n" + f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}" + ) return input_ids + @requires_transformers("4.55") + @requires_torch("2.9") @hide_stdout() def test_onnx_generate(self): mid = "arnir0/Tiny-LLM" @@ -83,25 +116,25 @@ def test_onnx_generate(self): exporter="custom", ) - print("-- test_onnx_generate: generate") - res, session = onnx_generate( - model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True - ) - n_inputs = input_ids.shape[1] - self.assertEqualArray(input_ids[:1], res[:, :n_inputs]) - self.assertEqual(res.dtype, torch.int64) - self.assertEqual(res.shape, (1, 13)) - print("-- test_onnx_generate: done") - # expected = model.generate(input_ids[:1], max_new_tokens=10) - expected = self.simple_generate_with_cache( - model, input_ids[:1], 2, max_new_tokens=10, session=session - ) - self.assertEqualArray(input_ids[:1], expected[:, :n_inputs]) - print("******", res) - print("******", expected) - 
self.assertEqual(expected.dtype, torch.int64) - self.assertEqual(expected.shape, (1, 13)) - self.assertEqualArray(expected, res) + print("-- test_onnx_generate: generate") + res, session = onnx_generate( + model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True + ) + n_inputs = input_ids.shape[1] + self.assertEqualArray(input_ids[:1], res[:, :n_inputs]) + self.assertEqual(res.dtype, torch.int64) + self.assertEqual(res.shape, (1, 13)) + print("-- test_onnx_generate: done") + # expected = model.generate(input_ids[:1], max_new_tokens=10) + expected = self.simple_generate_with_cache( + model, input_ids[:1], 2, max_new_tokens=10, session=session + ) + self.assertEqualArray(input_ids[:1], expected[:, :n_inputs]) + print("******", res) + print("******", expected) + self.assertEqual(expected.dtype, torch.int64) + self.assertEqual(expected.shape, (1, 13)) + self.assertEqualArray(expected, res) if __name__ == "__main__": diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index d140c890..4b30fcca 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1057,6 +1057,20 @@ def max_diff( allow_unique_tensor_with_list_of_one_element=False, hist=hist, ) + + if expected.__class__.__name__ == "CausalLMOutputWithPast": + if verbose >= 6: + print( + f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} " + f"? {string_type(got)}" + ) + return max_diff( + [expected.logits, *flatten_object(expected.past_key_values)], + got, + debug_info=_debug(expected.__class__.__name__), + **_dkws, + ) + if hasattr(expected, "to_tuple"): if verbose >= 6: print(f"[max_diff] to_tuple1: {string_type(expected)} ? {string_type(got)}") @@ -1067,36 +1081,6 @@ def max_diff( print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}") return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws) - if isinstance(got, (list, tuple)): - if len(got) != 1: - if verbose >= 6: - print( - f"[max_diff] list,tuple,2: {string_type(expected)} " - f"? {string_type(got)}" - ) - if verbose > 2: - import torch - - print( - f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, " - f"len(got)={len(got)}, level={level}, _index={_index}" - ) - for i, (a, b) in enumerate(zip(expected, got)): - if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): - print( - f" i={i} expected {a.dtype}:{a.shape}, " - f"has {b.dtype}:{b.shape}, _index={_index}" - ) - else: - print( - f" i={i} a is {type(a)}, " - f"b is {type(b)}, _index={_index}" - ) - return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) - if verbose >= 6: - print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}") - return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws) - if isinstance(expected, (tuple, list)): if verbose >= 6: print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}") @@ -1485,7 +1469,7 @@ def max_diff( return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if verbose >= 6: print( - f"[max_diff] {expected.__class__.__name__}: " + f"[max_diff*] {expected.__class__.__name__}: " f"{string_type(expected)} ? 
{string_type(got)}" ) expected_args, _spec = torch.utils._pytree.tree_flatten(expected) diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index 5aa896f4..9ce89eb1 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -122,6 +122,31 @@ def rt_type_to_torch_dtype(typename: str) -> torch.dtype: return _DTYPES[typename] +def make_empty_cache( + batch: int, + onnx_input_names: List[str], + onnx_input_shapes: List[Tuple[Union[int, str], ...]], + onnx_input_types: List[str], +) -> Dict[str, torch.Tensor]: + """ + Creates an empty cache. Example: + + .. code-block:: python + + make_empty_cache( + 1, + sess.input_names[2:], + [i.shape for i in sess.get_inputs()[2:]], + [i.type for i in sess.get_inputs()[2:]], + ) + """ + feeds = {} + for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types): + new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape)) + feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype)) + return feeds + + def onnx_generate( model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch], input_ids: torch.Tensor, @@ -166,12 +191,10 @@ def onnx_generate( attention_mask=torch.ones( input_ids.shape, dtype=input_ids.dtype, device=input_ids.device ), + **make_empty_cache( + input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:] + ), ) - for name, shape, dtype in zip(input_names[2:], input_shapes[2:], input_types[2:]): - new_shape = tuple( - _get_dim(i, s, batch=input_ids.shape[0]) for i, s in enumerate(shape) - ) - feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype)) outputs = session.run(None, feeds) From bfac0a1673cbe7dca2e8258631034713c537096c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 11:18:05 +0100 Subject: [PATCH 09/20] fix caches --- _unittests/ut_helpers/test_helper.py | 55 ++++++++++++++++++- .../ut_torch_models/test_validate_models.py | 4 +- onnx_diagnostic/helpers/cache_helper.py | 25 +++++++-- onnx_diagnostic/helpers/helper.py | 7 +++ 4 files changed, 81 insertions(+), 10 deletions(-) diff --git a/_unittests/ut_helpers/test_helper.py b/_unittests/ut_helpers/test_helper.py index 59b4c829..d1bde2db 100644 --- a/_unittests/ut_helpers/test_helper.py +++ b/_unittests/ut_helpers/test_helper.py @@ -10,6 +10,7 @@ skipif_ci_windows, hide_stdout, requires_onnx, + requires_transformers, ) from onnx_diagnostic.helpers.helper import ( string_type, @@ -40,7 +41,13 @@ onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype, ) -from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache +from onnx_diagnostic.helpers.cache_helper import ( + make_dynamic_cache, + make_encoder_decoder_cache, + make_static_cache, + make_hybrid_cache, + make_sliding_window_cache, +) from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config @@ -584,11 +591,55 @@ def test_flatten_encoder_decoder_cache(self): s = string_type(inputs) self.assertIn("EncoderDecoderCache", s) - def test_string_typeçconfig(self): + def test_string_type_config(self): conf = get_pretrained_config("microsoft/phi-2", use_only_preinstalled=True) s = string_type(conf) self.assertStartsWith("PhiConfig(**{", s) + @requires_transformers("4.55") + def test_max_diff_causal_output(self): + from transformers.modeling_outputs import CausalLMOutputWithPast + + logits = torch.rand((3, 4)) + cache = make_dynamic_cache([(torch.rand((3, 4)), torch.rand((3, 4)))]) + out1 = 
CausalLMOutputWithPast(logits=logits, past_key_values=cache) + out2 = CausalLMOutputWithPast(logits=logits, past_key_values=cache) + self.assertEqual(max_diff(out1, out2)["abs"], 0) + self.assertEqual( + max_diff(out1, [logits, cache.layers[0].keys, cache.layers[0].values])["abs"], 0 + ) + + def test_max_diff_others(self): + t = torch.rand((3, 4)) + self.assertEqual(max_diff(t, t)["abs"], 0) + self.assertEqual(max_diff([t], [t])["abs"], 0) + self.assertEqual(max_diff([t], (t,))["abs"], 0) + self.assertEqual(max_diff((t,), [t])["abs"], 0) + self.assertEqual(max_diff((t,), (t,))["abs"], 0) + self.assertEqual(max_diff({"t": t}, {"t": t})["abs"], 0) + + def test_max_diff_caches(self): + cache = make_dynamic_cache([(torch.rand((3, 4)), torch.rand((3, 4)))]) + self.assertEqual(max_diff(cache, cache)["abs"], 0) + cache = make_static_cache( + [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))], max_cache_len=3 + ) + self.assertEqual(max_diff(cache, cache)["abs"], 0) + cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]) + self.assertEqual(max_diff(cache, cache)["abs"], 0) + cache = make_sliding_window_cache( + [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))] + ) + self.assertEqual(max_diff(cache, cache)["abs"], 0) + cache = make_encoder_decoder_cache(cache, cache) + self.assertEqual(max_diff(cache, cache)["abs"], 0) + + def test_max_diff_caches_flat(self): + data = [(torch.rand((3, 4)), torch.rand((3, 4)))] + cache1 = make_dynamic_cache(data) + cache2 = make_dynamic_cache([*data[0]]) + self.assertEqual(max_diff(cache1, cache2)["abs"], 0) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index 6bbf60ee..e70f8331 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -39,7 +39,7 @@ def test_validate_tiny_llms_bfloat16(self): self.assertIn("onnx_filename", data) @requires_transformers("4.53") - @requires_torch("2.7.99") + @requires_torch("2.8.99") @requires_experimental() @hide_stdout() def test_validate_microsoft_phi4_reasoning(self): @@ -60,7 +60,7 @@ def test_validate_microsoft_phi4_reasoning(self): self.assertIn("onnx_filename", data) @requires_transformers("4.53") - @requires_torch("2.7.99") + @requires_torch("2.8.99") @requires_experimental() @hide_stdout() def test_validate_microsoft_phi3_mini_128k(self): diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index f3fdc5ef..3d06d4b4 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import packaging.version as pv import torch import transformers @@ -152,10 +152,18 @@ def make_dynamic_shapes_kv_cache( return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)] +def _preprocess_key_value_pairs( + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], +) -> List[Tuple[torch.Tensor, torch.Tensor]]: + if not key_value_pairs or isinstance(key_value_pairs[0], tuple): + return key_value_pairs + return list(zip(key_value_pairs[::2], key_value_pairs[1::2])) + + if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: 
Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. @@ -191,6 +199,7 @@ def make_dynamic_cache( ``transformers>=4.56``. Before that version, only FakeTensor with static dimensions are supported. """ + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) if ( key_value_pairs and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor) @@ -230,7 +239,7 @@ def make_dynamic_cache( else: def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. @@ -262,6 +271,7 @@ def make_dynamic_cache( ) print(string_type(past_key_values, with_shape=True)) """ + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore for i, (key, value) in enumerate(key_value_pairs): cache.update(key, value, i) @@ -269,7 +279,7 @@ def make_dynamic_cache( def make_static_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], max_cache_len: Optional[int] = None, ) -> transformers.cache_utils.DynamicCache: """ @@ -302,6 +312,7 @@ def make_static_cache( ) print(string_type(past_key_values, with_shape=True)) """ + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) class _config: def __init__(self): @@ -444,9 +455,10 @@ def get_text_config(self, *args, **kwargs): def make_sliding_window_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], ) -> transformers.cache_utils.SlidingWindowCache: "Creates a :class:`transformers.cache_utils.SlidingWindowCache`." + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) class _config: def __init__(self): @@ -499,7 +511,7 @@ def get_text_config(self, *args, **kwargs): def make_hybrid_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], max_cache_len: Optional[int] = None, max_batch_size: Optional[int] = None, sliding_window: Optional[int] = None, @@ -584,6 +596,7 @@ def make_hybrid_cache( self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) """ + key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) layer_types = None if key_value_pairs: assert ( diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 4b30fcca..fa0e4168 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1064,6 +1064,13 @@ def max_diff( f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} " f"? 
{string_type(got)}" ) + if got.__class__.__name__ == "CausalLMOutputWithPast": + return max_diff( + [expected.logits, *flatten_object(expected.past_key_values)], + [got.logits, *flatten_object(got.past_key_values)], + debug_info=_debug(expected.__class__.__name__), + **_dkws, + ) return max_diff( [expected.logits, *flatten_object(expected.past_key_values)], got, From b7d5dd7eaf8b8821f119f5ec22b54d3667745785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 12:23:46 +0100 Subject: [PATCH 10/20] more tests --- _doc/examples/plot_export_hub_codellama.py | 4 +- _doc/examples/plot_export_tiny_phi2.py | 4 +- _unittests/ut_export/test_api.py | 59 ++++++++++++++++++- _unittests/ut_export/test_dynamic_shapes.py | 21 +++++++ .../ut_helpers/test_model_builder_helper.py | 4 +- _unittests/ut_helpers/test_rt_helper.py | 2 +- .../test_documentation_examples.py | 3 +- onnx_diagnostic/export/api.py | 38 +++++++++++- 8 files changed, 121 insertions(+), 14 deletions(-) diff --git a/_doc/examples/plot_export_hub_codellama.py b/_doc/examples/plot_export_hub_codellama.py index 56cfd747..5f82096a 100644 --- a/_doc/examples/plot_export_hub_codellama.py +++ b/_doc/examples/plot_export_hub_codellama.py @@ -22,9 +22,7 @@ from onnx_diagnostic import doc from onnx_diagnostic.ext_test_case import unit_test_going from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.torch_models.hghub import ( - get_untrained_model_with_inputs, -) +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs from onnx_diagnostic.torch_models.hghub.hub_api import ( get_model_info, get_pretrained_config, diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py index b4979334..e3f138ad 100644 --- a/_doc/examples/plot_export_tiny_phi2.py +++ b/_doc/examples/plot_export_tiny_phi2.py @@ -33,9 +33,7 @@ from onnx_diagnostic.helpers.rt_helper import make_feeds from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -from onnx_diagnostic.torch_models.hghub import ( - get_untrained_model_with_inputs, -) +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs warnings.simplefilter("ignore") diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py index 75a8a255..f9113db8 100644 --- a/_unittests/ut_export/test_api.py +++ b/_unittests/ut_export/test_api.py @@ -1,6 +1,11 @@ import unittest import torch from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.helpers import max_diff +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy +from onnx_diagnostic.helpers.rt_helper import make_feeds +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs +from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.export.api import to_onnx @@ -19,16 +24,66 @@ def forward(self, x, y): (x, y), dynamic_shapes=ds, exporter="custom", - filename=self.get_dump_file("custom.onnx"), + filename=self.get_dump_file("to_onnx_custom.onnx"), ) to_onnx( Model(), (x, y), dynamic_shapes=ds, exporter="onnx-dynamo", - filename=self.get_dump_file("onnx-dynamo.onnx"), + filename=self.get_dump_file("to_onnx_onnx-dynamo.onnx"), ) + @hide_stdout() + def test_tiny_llm_to_onnx(self): + import onnxruntime + + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + b1 = 
data["inputs_batch1"] + filenames = { + "custom": self.get_dump_file("test_tiny_llm_to_onnx-custom.onnx"), + "onnx-dynamo": self.get_dump_file("test_tiny_llm_to_onnx-dynamo.onnx"), + "modelbuilder": self.get_dump_file("model.onnx"), + } + del inputs["position_ids"] + del ds["position_ids"] + del b1["position_ids"] + + expected = model(**torch_deepcopy(b1)) + + with torch_export_patches(patch_transformers=True): + for exporter, filename in filenames.items(): + with self.subTest(exporter=exporter): + to_onnx( + model, + kwargs=inputs, + dynamic_shapes=ds, + exporter=exporter, + filename=filename, + ) + for exporter, filename in filenames.items(): + with self.subTest(exporter=f"validate-{exporter}"): + sess = onnxruntime.InferenceSession( + filename, providers=["CPUExecutionProvider"] + ) + feeds = make_feeds(sess, b1, use_numpy=True) + got = sess.run(None, feeds) + diff = max_diff(expected, got) + assert diff["abs"] <= 1e-5, f"diff={diff}" + + b1["attention_mask"][:, :] = 1 + expected = model(**torch_deepcopy(b1)) + for exporter, filename in filenames.items(): + with self.subTest(exporter=f"full-mask-{exporter}"): + sess = onnxruntime.InferenceSession( + filename, providers=["CPUExecutionProvider"] + ) + feeds = make_feeds(sess, b1, use_numpy=True) + got = sess.run(None, feeds) + diff = max_diff(expected, got) + assert diff["abs"] <= 1e-5, f"diff={diff}" + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 2b181961..f0b4619f 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -916,6 +916,27 @@ def forward(self, cache, z): ds, ) + def test_invalid_dimensions_for_export(self): + ags = [] + kws = dict( + input_ids=torch.randint(0, 10, (2, 3)), + attention_mask=torch.randint(0, 1, (2, 33)), + position_ids=torch.randint(0, 10, (2, 3)), + past_key_values=make_dynamic_cache( + [torch.rand((2, 1, 30, 96)), torch.rand((2, 1, 30, 96))] + ), + ) + ds = dict( + input_ids={0: "batch", 1: "seq_length"}, + attention_mask={0: "batch", 1: "seq_length"}, + position_ids={0: "batch", 1: "seq_length"}, + past_key_values=[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}], + ) + with torch_export_patches(patch_transformers=True): + cpl = CoupleInputsDynamicShapes(ags, kws, ds) + backed_size_oblivious = cpl.invalid_dimensions_for_export() + self.assertFalse(backed_size_oblivious) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_helpers/test_model_builder_helper.py b/_unittests/ut_helpers/test_model_builder_helper.py index 94fe28f2..c61cabb3 100644 --- a/_unittests/ut_helpers/test_model_builder_helper.py +++ b/_unittests/ut_helpers/test_model_builder_helper.py @@ -12,9 +12,7 @@ create_model_builder, save_model_builder, ) -from onnx_diagnostic.torch_models.hghub import ( - get_untrained_model_with_inputs, -) +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs from onnx_diagnostic.helpers.rt_helper import make_feeds diff --git a/_unittests/ut_helpers/test_rt_helper.py b/_unittests/ut_helpers/test_rt_helper.py index 54648918..6ce2a275 100644 --- a/_unittests/ut_helpers/test_rt_helper.py +++ b/_unittests/ut_helpers/test_rt_helper.py @@ -113,7 +113,7 @@ def test_onnx_generate(self): kwargs=inputs, dynamic_shapes=ds, filename=model_name, - exporter="custom", + exporter="modelbuilder", ) print("-- test_onnx_generate: generate") diff --git 
a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 910fe21f..8dfba825 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -102,7 +102,7 @@ def add_test_methods(cls): if ( not reason - and name in {"plot_export_with_dynamic_cache.py", "plot_export_tiny_phi2.py"} + and name in {"plot_export_tiny_phi2.py", "plot_export_with_dynamic_cache.py"} and not has_transformers("4.55") ): reason = "transformers<4.55" @@ -117,6 +117,7 @@ def add_test_methods(cls): "plot_export_locate_issue.py", "plot_export_with_auto.py", "plot_export_tiny_llm.py", + "plot_export_with_dynamic_cache.py", } and not has_torch("2.8") ): diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 18707c64..9943830a 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -42,7 +42,7 @@ def to_onnx( ), f"output_dynamic_shapes not supported for exporter={exporter!r}" epo = torch.onnx.export( mod, - args=args, + args=args or tuple(), kwargs=kwargs, input_names=input_names, output_names=output_names, @@ -54,4 +54,40 @@ def to_onnx( epo.save(filename) return epo + if exporter == "modelbuilder": + import os + from ..helpers import flatten_object, string_type + from ..helpers.model_builder_helper import create_model_builder, save_model_builder + + assert filename, f"filename must be specified for exporter={exporter!r}" + assert ( + not output_dynamic_shapes + ), f"output_dynamic_shapes not supported for exporter={exporter!r}" + assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}" + assert not args, f"only kwargs can be defined with exporter={exporter!r}" + assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], ( + f"Only a specified set of inputs is supported for exporter={exporter!r}, " + f"but it is {list(kwargs)}" + ) + flat_inputs = flatten_object(kwargs, drop_keys=True) + first = flat_inputs[0] + first_float = [ + t + for t in flat_inputs + if t.dtype in {torch.float32, torch.double, torch.float16, torch.bfloat16} + ] + assert first_float, ( + f"Unable to find a float tensor in the inputs " + f"{string_type(kwargs, with_shape=True)}" + ) + onx = create_model_builder( + mod.config, + mod, + precision=str(first_float[0].dtype).split(".")[-1], + execution_provider="cuda" if first.is_cuda else "cpu", + cache_dir=os.path.dirname(filename), + ) + save_model_builder(onx, os.path.dirname(filename)) + return onx + raise ValueError(f"Unknown exporter={exporter!r}") From 2b922187aa5ea66024bca993b74c379463a93828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 13:03:19 +0100 Subject: [PATCH 11/20] fix version --- _unittests/ut_xrun_doc/test_documentation_examples.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 8dfba825..f199b6ce 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -117,12 +117,18 @@ def add_test_methods(cls): "plot_export_locate_issue.py", "plot_export_with_auto.py", "plot_export_tiny_llm.py", - "plot_export_with_dynamic_cache.py", } and not has_torch("2.8") ): reason = "torch<2.8" + if ( + not reason + and name in {"plot_export_with_dynamic_cache.py"} + and not has_torch("2.9") + ): + reason = "does not work with 2.8" + if ( not 
reason and name in {"plot_dump_intermediate_results.py"} From 2523b0da68b3f9f1bbaaecc7b03013a4ea811393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 14:32:57 +0100 Subject: [PATCH 12/20] fix issues --- _doc/technical/plot_generate.py | 2 +- _unittests/ut_export/test_api.py | 23 +++- _unittests/ut_helpers/test_rt_helper.py | 8 +- .../test_patch_transformers.py | 126 ++++++++++++++++++ onnx_diagnostic/helpers/torch_helper.py | 8 +- .../patches/patch_transformers.py | 106 ++++++++++++--- 6 files changed, 246 insertions(+), 27 deletions(-) create mode 100644 _unittests/ut_torch_export_patches/test_patch_transformers.py diff --git a/_doc/technical/plot_generate.py b/_doc/technical/plot_generate.py index 0edab8e1..8ebcc4e0 100644 --- a/_doc/technical/plot_generate.py +++ b/_doc/technical/plot_generate.py @@ -155,7 +155,7 @@ def simple_generate_with_cache( dtype = get_weight_type(model) print("-- model dtype:", dtype) export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype) -exporter = "custom" if "custom" in sys.argv else "onnx-dynamo" +exporter = "onnx-dynamo" if "dynamo" in sys.argv else "custom" model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx" if not os.path.exists(model_name): # This step is slow so let's skip it if it was already done. diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py index f9113db8..5685a6a2 100644 --- a/_unittests/ut_export/test_api.py +++ b/_unittests/ut_export/test_api.py @@ -1,9 +1,10 @@ import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers from onnx_diagnostic.helpers import max_diff from onnx_diagnostic.helpers.torch_helper import torch_deepcopy from onnx_diagnostic.helpers.rt_helper import make_feeds +from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.export.api import to_onnx @@ -46,6 +47,10 @@ def test_tiny_llm_to_onnx(self): "onnx-dynamo": self.get_dump_file("test_tiny_llm_to_onnx-dynamo.onnx"), "modelbuilder": self.get_dump_file("model.onnx"), } + if not has_transformers("4.55"): + # <4.55: torch._check(causal_mask.shape[3] != 33) + # torch._check(causal_mask.shape[3] == 33) + del filenames["onnx-dynamo"] del inputs["position_ids"] del ds["position_ids"] del b1["position_ids"] @@ -72,14 +77,24 @@ def test_tiny_llm_to_onnx(self): diff = max_diff(expected, got) assert diff["abs"] <= 1e-5, f"diff={diff}" - b1["attention_mask"][:, :] = 1 - expected = model(**torch_deepcopy(b1)) + problem = dict( + input_ids=torch.tensor([[24320]], dtype=torch.int64), + attention_mask=torch.tensor([[1, 1, 1, 1]], dtype=torch.int64), + past_key_values=make_dynamic_cache( + [ + torch.rand((1, 1, 3, 96), dtype=torch.float32), + torch.rand((1, 1, 3, 96), dtype=torch.float32), + ] + ), + ) + + expected = model(**torch_deepcopy(problem)) for exporter, filename in filenames.items(): with self.subTest(exporter=f"full-mask-{exporter}"): sess = onnxruntime.InferenceSession( filename, providers=["CPUExecutionProvider"] ) - feeds = make_feeds(sess, b1, use_numpy=True) + feeds = make_feeds(sess, problem, use_numpy=True) got = sess.run(None, feeds) diff = max_diff(expected, got) assert diff["abs"] <= 1e-5, f"diff={diff}" diff --git a/_unittests/ut_helpers/test_rt_helper.py 
b/_unittests/ut_helpers/test_rt_helper.py index 6ce2a275..a016480f 100644 --- a/_unittests/ut_helpers/test_rt_helper.py +++ b/_unittests/ut_helpers/test_rt_helper.py @@ -48,7 +48,8 @@ def simple_generate_with_cache( f"\ninput_ids.shape={input_ids.shape}" f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}" f"\n got=\n" - f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}" + f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}\n" + f"feeds={self.string_type(feeds, with_shape=True, with_min_max=True)}" ) # Next calls: decode @@ -87,7 +88,8 @@ def simple_generate_with_cache( f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}" f"\nexpected={self.string_type(outputs, with_shape=True, with_min_max=True)}" f"\n got=\n" - f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}" + f"{self.string_type(onnx_results, with_shape=True, with_min_max=True)}\n" + f"feeds={self.string_type(feeds, with_shape=True, with_min_max=True)}" ) return input_ids @@ -113,7 +115,7 @@ def test_onnx_generate(self): kwargs=inputs, dynamic_shapes=ds, filename=model_name, - exporter="modelbuilder", + exporter="custom", ) print("-- test_onnx_generate: generate") diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py new file mode 100644 index 00000000..a8a0287d --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -0,0 +1,126 @@ +import unittest +import torch +import transformers +import transformers.integrations.sdpa_attention as sdpa_attention +import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy + + +class TestPatchPatchTransformers(ExtTestCase): + @requires_transformers("4.55") + def test_sdpa_mask_recent_torch(self): + sdpa_mask_recent_torch = transformers.masking_utils.sdpa_mask_recent_torch + patched_sdpa_mask_recent_torch = patch_transformers.patched_sdpa_mask_recent_torch + kwargs = { + "batch_size": 1, + "cache_position": torch.tensor([3], dtype=torch.int64), + "kv_length": 4, + "kv_offset": 0, + "mask_function": transformers.masking_utils.causal_mask_function, + "attention_mask": torch.tensor([[True, True, True, True]]), + "local_size": None, + "allow_is_causal_skip": True, + "allow_is_bidirectional_skip": False, + } + expected = sdpa_mask_recent_torch(**kwargs) + got = patched_sdpa_mask_recent_torch(**kwargs) + self.assertEqual(expected, got) + + kwargs = { + "batch_size": 1, + "cache_position": torch.tensor([3], dtype=torch.int64), + "kv_length": 4, + "kv_offset": 0, + "mask_function": transformers.masking_utils.causal_mask_function, + "attention_mask": torch.tensor([[True, True, True, True]]), + "local_size": None, + "allow_is_causal_skip": False, + "allow_is_bidirectional_skip": False, + } + expected = sdpa_mask_recent_torch(**kwargs) + got = patched_sdpa_mask_recent_torch(**kwargs) + self.assertEqualArray(expected, got) + + @requires_transformers("4.55") + def test_sdpa_attention_forward_not_causal(self): + sdpa_attention_forward = sdpa_attention.sdpa_attention_forward + patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward + kwargs = { + "module": None, + "query": torch.rand((1, 2, 1, 96), dtype=torch.float32), + "key": torch.rand((1, 2, 4, 96), dtype=torch.float32), + "value": torch.rand((1, 2, 4, 96), 
dtype=torch.float32), + "attention_mask": None, + "attention_dropout": 0, + "scaling": 0.10206207261596575, + "is_causal": False, + } + expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + self.assertEqualArray(expected, got) + + kwargs = { + "module": None, + "query": torch.rand((1, 2, 1, 96), dtype=torch.float32), + "key": torch.rand((1, 2, 4, 96), dtype=torch.float32), + "value": torch.rand((1, 2, 4, 96), dtype=torch.float32), + "attention_mask": torch.tensor([[[[True, True, True, True]]]]), + "attention_dropout": 0, + "scaling": 0.10206207261596575, + "is_causal": False, + } + expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + self.assertEqualArray(expected, got) + + @requires_transformers("4.55") + def test_sdpa_attention_forward_causal(self): + sdpa_attention_forward = sdpa_attention.sdpa_attention_forward + patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward + kwargs = { + "module": None, + "query": torch.rand((1, 2, 1, 96), dtype=torch.float32), + "key": torch.rand((1, 2, 4, 96), dtype=torch.float32), + "value": torch.rand((1, 2, 4, 96), dtype=torch.float32), + "attention_mask": torch.tensor([[[[True, True, True, True]]]]), + "attention_dropout": 0, + "scaling": 0.10206207261596575, + "is_causal": True, + } + expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + self.assertEqualArray(expected, got) + + kwargs = { + "module": None, + "query": torch.rand((1, 2, 1, 96), dtype=torch.float32), + "key": torch.rand((1, 2, 4, 96), dtype=torch.float32), + "value": torch.rand((1, 2, 4, 96), dtype=torch.float32), + "attention_mask": None, + "attention_dropout": 0, + "scaling": 0.10206207261596575, + "is_causal": True, + } + expected = sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + got = patched_sdpa_attention_forward(**torch_deepcopy(kwargs))[0] + self.assertEqualArray(expected, got) + + def test_causal_mask_in_scaled_dot_product_attention(self): + # see https://docs.pytorch.org/docs/stable/generated/... + # ...torch.nn.functional.scaled_dot_product_attention.html + + query = torch.rand((1, 2, 1, 96), dtype=torch.float32) + key = torch.rand((1, 2, 4, 96), dtype=torch.float32) + L, S = query.size(-2), key.size(-2) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + self.assertEqual(attn_bias.min().item(), 0) + attn_causal_bias = attn_bias.clone() + + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_causal_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + self.assertEqual(attn_causal_bias.min().item(), -float("inf")) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 65634548..2ff5c2f9 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -856,9 +856,15 @@ def torch_deepcopy(value: Any) -> Any: ), f"Unexpected type={type(value)}" return copy.deepcopy(value) + if hasattr(value, "__nocopy__"): + return value + # We should have a code using serialization, deserialization assuming a model # cannot be exported without them. 
- raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}") + raise NotImplementedError( + f"torch_deepcopy not implemented for type {type(value)}, " + f"add attribute '__nocopy__' to return it as is." + ) def torch_tensor_size(value: Any) -> Any: diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 2d8a72d3..5a5f17b4 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -39,19 +39,45 @@ except ImportError: patch_DynamicLayer = False -from ...ext_test_case import has_transformers -from ...helpers.torch_helper import is_torchdynamo_exporting -patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99") +def _has_transformers(version: str) -> bool: + return pv.Version(transformers.__version__) >= pv.Version(version) + + +def _is_torchdynamo_exporting() -> bool: + """ + Tells if :epkg:`torch` is exporting a model. + Relies on ``torch.compiler.is_exporting()``. + """ + import torch + + if not hasattr(torch.compiler, "is_exporting"): + # torch.compiler.is_exporting requires torch>=2.7 + return False + + try: + return torch.compiler.is_exporting() + except Exception: + try: + import torch._dynamo as dynamo + + return dynamo.is_exporting() # type: ignore + except Exception: + return False + + +patch_is_initialized = _has_transformers("4.56.99") if patch_masking_utils: # Introduced in 4.52 from transformers.masking_utils import ( + _ignore_causal_mask_sdpa, + _ignore_bidirectional_mask_sdpa, + and_masks, + bidirectional_mask_function, causal_mask_function, padding_mask_function, - and_masks, - _ignore_causal_mask_sdpa, prepare_padding_mask, ) @@ -98,7 +124,7 @@ def vector_mask_function( # for a, dims in zip(args, udimensions) # ] max_shape = tuple(args[i].shape[0] for i in indices) - # if is_torchdynamo_exporting(): + # if _is_torchdynamo_exporting(): # for a in args: # # The exporter should export with a dimension > 1 # # to make sure it is dynamic. @@ -151,6 +177,7 @@ def patched_sdpa_mask_recent_torch( attention_mask: Optional[torch.Tensor] = None, local_size: Optional[int] = None, allow_is_causal_skip: bool = True, + allow_is_bidirectional_skip: bool = False, **kwargs, ) -> Optional[torch.Tensor]: """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``.""" @@ -160,6 +187,25 @@ def patched_sdpa_mask_recent_torch( padding_mask, q_length, kv_length, kv_offset, local_size ): return None + if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask): + return None + + if mask_function is bidirectional_mask_function: + if padding_mask is not None: + # used for slicing without data-dependent slicing + mask_indices = ( + torch.arange(kv_length, device=cache_position.device) + kv_offset + ) + return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1) + return torch.ones( + batch_size, + 1, + q_length, + kv_length, + dtype=torch.bool, + device=cache_position.device, + ) + kv_arange = torch.arange(kv_length, device=cache_position.device) kv_arange += kv_offset if padding_mask is not None: @@ -275,7 +321,7 @@ class patched_AttentionMaskConverter: """ # This method was fixed in 4.51 at least. 
- _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else [] + _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else [] _PATCHED_CLASS_ = AttentionMaskConverter @staticmethod @@ -507,7 +553,7 @@ def _cache_dependant_input_preparation( The current implementation does not rely on ``self`` and could be a class method. It is left as a standard method to be easily rewritten. """ - if is_torchdynamo_exporting(): + if _is_torchdynamo_exporting(): return self._cache_dependant_input_preparation_exporting( input_ids, inputs_embeds, cache_position ) @@ -1316,16 +1362,40 @@ def patched_sdpa_attention_forward( attention_mask is None or attention_mask.shape[3] == key.shape[2], "Attention mask shape incompatible with key shape.", ) - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=dropout, - scale=scaling, - is_causal=is_causal, - **sdpa_kwargs, - ) + if is_causal: + attn_output = torch.cond( + query.shape[2] > 1, # distinction between prefill and decoding steps + lambda query, key, value: torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + dropout_p=dropout, + scale=scaling, + is_causal=True, + **sdpa_kwargs, + ), + lambda query, key, value: torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + dropout_p=dropout, + scale=scaling, + is_causal=False, + **sdpa_kwargs, + ), + [query, key, value], + ) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + **sdpa_kwargs, + ) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None From c46fe1e871040129f0a4755e3f68663d5c6d9503 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 14:47:57 +0100 Subject: [PATCH 13/20] mypy --- onnx_diagnostic/export/api.py | 4 ++-- .../patches/patch_transformers.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 9943830a..6dcd1dd2 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -65,9 +65,9 @@ def to_onnx( ), f"output_dynamic_shapes not supported for exporter={exporter!r}" assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}" assert not args, f"only kwargs can be defined with exporter={exporter!r}" - assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], ( + assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], ( # type: ignore[arg-type] f"Only a specified set of inputs is supported for exporter={exporter!r}, " - f"but it is {list(kwargs)}" + f"but it is {list(kwargs)}" # type: ignore[arg-type] ) flat_inputs = flatten_object(kwargs, drop_keys=True) first = flat_inputs[0] diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 5a5f17b4..01664caa 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -73,7 +73,6 @@ def _is_torchdynamo_exporting() -> bool: # Introduced in 4.52 from transformers.masking_utils import ( _ignore_causal_mask_sdpa, - _ignore_bidirectional_mask_sdpa, and_masks, bidirectional_mask_function, causal_mask_function, @@ -81,6 +80,12 @@ def 
_is_torchdynamo_exporting() -> bool: prepare_padding_mask, ) + try: + # transformers>=5.0 + from transformers.masking_utils import _ignore_bidirectional_mask_sdpa + except ImportError: + _ignore_bidirectional_mask_sdpa = None + def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" from ...helpers import string_type @@ -187,7 +192,11 @@ def patched_sdpa_mask_recent_torch( padding_mask, q_length, kv_length, kv_offset, local_size ): return None - if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask): + if ( + allow_is_bidirectional_skip + and _ignore_bidirectional_mask_sdpa + and _ignore_bidirectional_mask_sdpa(padding_mask) + ): return None if mask_function is bidirectional_mask_function: From 3c37d9d0e13afaca2ea2b51a22db5661c0aed3c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 15:06:23 +0100 Subject: [PATCH 14/20] import --- .../torch_export_patches/patches/patch_transformers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 01664caa..0ac41694 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -74,7 +74,6 @@ def _is_torchdynamo_exporting() -> bool: from transformers.masking_utils import ( _ignore_causal_mask_sdpa, and_masks, - bidirectional_mask_function, causal_mask_function, padding_mask_function, prepare_padding_mask, @@ -82,9 +81,13 @@ def _is_torchdynamo_exporting() -> bool: try: # transformers>=5.0 - from transformers.masking_utils import _ignore_bidirectional_mask_sdpa + from transformers.masking_utils import ( + _ignore_bidirectional_mask_sdpa, + bidirectional_mask_function, + ) except ImportError: _ignore_bidirectional_mask_sdpa = None + bidirectional_mask_function = None def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" From ce4fb6672976d74feb79d8f42baa6749ccb5372e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 15:24:51 +0100 Subject: [PATCH 15/20] fix issues --- _unittests/ut_xrun_doc/test_documentation_examples.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index f199b6ce..bda3cb20 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -102,11 +102,18 @@ def add_test_methods(cls): if ( not reason - and name in {"plot_export_tiny_phi2.py", "plot_export_with_dynamic_cache.py"} + and name in {"plot_export_tiny_phi2.py"} and not has_transformers("4.55") ): reason = "transformers<4.55" + if ( + not reason + and name in {"plot_export_with_dynamic_cache.py"} + and not has_transformers("4.56") + ): + reason = "transformers<4.56" + # pytorch if ( From 37bdbd617624359096728d4380400a42e9ec75d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 29 Oct 2025 19:46:55 +0100 Subject: [PATCH 16/20] onnx_generate_with_genai --- _doc/technical/plot_generate.py | 3 + _unittests/ut_helpers/test_rt_helper.py | 59 ++++++---- onnx_diagnostic/ext_test_case.py | 11 
++
 .../helpers/model_builder_helper.py           | 104 +++++++++++++++++-
 onnx_diagnostic/helpers/rt_helper.py          |  74 ++++++++++++-
 requirements-dev.txt                          |   1 +
 6 files changed, 229 insertions(+), 23 deletions(-)

diff --git a/_doc/technical/plot_generate.py b/_doc/technical/plot_generate.py
index 8ebcc4e0..bdda90c0 100644
--- a/_doc/technical/plot_generate.py
+++ b/_doc/technical/plot_generate.py
@@ -94,6 +94,9 @@
 # %%
 # Custom method generate
 # ======================
+#
+# Let's implement a simple function replicating what the method
+# ``generate`` does.


 def simple_generate_with_cache(
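Editor's note: the comment added above refers to the greedy decoding loop the example
implements. As a reference, a minimal sketch of that pattern (names are illustrative;
it assumes a Hugging Face causal LM that returns ``past_key_values`` when
``use_cache=True``):

.. code-block:: python

    import torch

    def greedy_generate(model, input_ids, eos_token_id, max_new_tokens=100):
        # Prefill: run the whole prompt once and keep the key/value cache.
        attention_mask = torch.ones_like(input_ids)
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
        for _ in range(max_new_tokens):
            # Greedy choice: pick the most likely token at the last position.
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            if next_token.item() == eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # Decode step: feed only the new token, reuse the cached keys/values.
            outputs = model(
                next_token,
                attention_mask=torch.ones_like(input_ids),
                past_key_values=outputs.past_key_values,
                use_cache=True,
            )
        return input_ids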
diff --git a/_unittests/ut_helpers/test_rt_helper.py b/_unittests/ut_helpers/test_rt_helper.py
index a016480f..69527d09 100644
--- a/_unittests/ut_helpers/test_rt_helper.py
+++ b/_unittests/ut_helpers/test_rt_helper.py
@@ -3,12 +3,17 @@
 import torch
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
+    has_onnxruntime_genai,
     hide_stdout,
     requires_transformers,
     requires_torch,
 )
 from onnx_diagnostic.helpers import max_diff, flatten_object
-from onnx_diagnostic.helpers.rt_helper import onnx_generate, make_empty_cache
+from onnx_diagnostic.helpers.rt_helper import (
+    onnx_generate,
+    onnx_generate_with_genai,
+    make_empty_cache,
+)
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
@@ -101,6 +106,7 @@ def test_onnx_generate(self):
         print("-- test_onnx_generate: get model")
         data = get_untrained_model_with_inputs(mid)
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        configuration = data["configuration"]
         del inputs["position_ids"]
         del ds["position_ids"]
         input_ids = inputs["input_ids"]
@@ -118,25 +124,38 @@ def test_onnx_generate(self):
                 exporter="custom",
             )

-            print("-- test_onnx_generate: generate")
-            res, session = onnx_generate(
-                model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
-            )
-            n_inputs = input_ids.shape[1]
-            self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
-            self.assertEqual(res.dtype, torch.int64)
-            self.assertEqual(res.shape, (1, 13))
-            print("-- test_onnx_generate: done")
-            # expected = model.generate(input_ids[:1], max_new_tokens=10)
-            expected = self.simple_generate_with_cache(
-                model, input_ids[:1], 2, max_new_tokens=10, session=session
-            )
-            self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
-            print("******", res)
-            print("******", expected)
-            self.assertEqual(expected.dtype, torch.int64)
-            self.assertEqual(expected.shape, (1, 13))
-            self.assertEqualArray(expected, res)
+        print("-- test_onnx_generate: generate")
+        res, session = onnx_generate(
+            model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
+        )
+        n_inputs = input_ids.shape[1]
+        self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
+        self.assertEqual(res.dtype, torch.int64)
+        self.assertEqual(res.shape, (1, 13))
+        print("-- test_onnx_generate: done")
+        # expected = model.generate(input_ids[:1], max_new_tokens=10)
+        expected = self.simple_generate_with_cache(
+            model, input_ids[:1], 2, max_new_tokens=10, session=session
+        )
+        self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
+        print("******", res)
+        print("******", expected)
+        self.assertEqual(expected.dtype, torch.int64)
+        self.assertEqual(expected.shape, (1, 13))
+        self.assertEqualArray(expected, res)
+
+        if not has_onnxruntime_genai():
+            raise unittest.SkipTest("onnxruntime_genai is missing")
+
+        res, session = onnx_generate_with_genai(
+            model_name,
+            input_ids[:1],
+            max_new_tokens=10,
+            return_session=True,
+            transformers_config=configuration,
+        )
+        self.assertNotEmpty(session)
+        self.assertEqualArray(expected, res)


 if __name__ == "__main__":
diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py
index cd68d7b2..d9d5f4a3 100644
--- a/onnx_diagnostic/ext_test_case.py
+++ b/onnx_diagnostic/ext_test_case.py
@@ -630,6 +630,17 @@ def has_onnxruntime_training(push_back_batch: bool = False):
     return True


+def has_onnxruntime_genai():
+    """Tells if onnxruntime_genai is installed."""
+    try:
+        import onnxruntime_genai  # noqa: F401
+
+        return True
+    except ImportError:
+        # onnxruntime_genai is not installed
+        return False
+
+
 def requires_onnxruntime_training(
     push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
 ) -> Callable:
diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py
index 8ee33abe..7f315df6 100644
--- a/onnx_diagnostic/helpers/model_builder_helper.py
+++ b/onnx_diagnostic/helpers/model_builder_helper.py
@@ -1,11 +1,12 @@
+import copy
 import importlib.util
 import os
 import requests
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union
 from urllib.parse import urlparse
-from onnx import ModelProto, TensorProto
+from onnx import ModelProto, TensorProto, load as load_model

 CACHE_SUBDIR = "onnx-diagnostic"
@@ -337,3 +338,102 @@ def _post(onnx_model):
     # onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
     # onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
     return onnx_model
+
+
+def make_genai_config(
+    config,
+    onnx_filename: str,
+) -> Dict:
+    """
+    Creates the genai configuration file for a model.
+
+    :param config: configuration from transformers
+    :param onnx_filename: path to the ONNX model
+    :return: configuration as a dictionary
+    """
+    onx = load_model(onnx_filename, load_external_data=False)
+    config = copy.deepcopy(config)
+    defaults = {
+        "bos_token_id": None,
+        "do_sample": False,
+        "eos_token_id": None,
+        "pad_token_id": None,
+        "temperature": 1.0,
+        "top_k": 50,
+        "top_p": 1.0,
+    }
+    for key, default_val in defaults.items():
+        if not hasattr(config, key):
+            setattr(config, key, default_val)
+
+    bos_token_id = (
+        config.bos_token_id
+        if hasattr(config, "bos_token_id") and config.bos_token_id is not None
+        else 1
+    )
+    eos_token_id = config.eos_token_id
+    pad_token_id = (
+        config.pad_token_id
+        if hasattr(config, "pad_token_id") and config.pad_token_id is not None
+        else (
+            config.eos_token_id[0]
+            if isinstance(config.eos_token_id, list)
+            else config.eos_token_id
+        )
+    )
+    input_names = [i.name for i in onx.graph.input]
+    output_names = [i.name for i in onx.graph.output]
+    past_key_values = [s for s in input_names if s.startswith("past_key_value")]
+    first = [i for i in onx.graph.input if i.name == past_key_values[0]][0]  # noqa: RUF015
+    shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
+    return {
+        "model": {
+            "bos_token_id": bos_token_id,
+            "context_length": config.max_position_embeddings,
+            "decoder": {
+                "session_options": {
+                    "log_id": "onnxruntime-genai",
+                    "provider_options": [],
+                },
+                "filename": onnx_filename,
+                "head_size": shape[-1],
+                "hidden_size": config.hidden_size,
+                "inputs": input_names,
+                "outputs": output_names,
+                "num_attention_heads": config.num_attention_heads,
+                "num_hidden_layers": len(past_key_values) // 2,
+                "num_key_value_heads": shape[1],
+            },
+            "eos_token_id": eos_token_id,
+            "pad_token_id": pad_token_id,
+            # "type": self.model_type[ : self.model_type.find("For")
+            # if "For" in self.model_type else len(self.model_type)].lower(),
+            "vocab_size": config.vocab_size,
+        },
+        "search": {
+            "diversity_penalty": (
+                config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
+            ),
+            "do_sample": config.do_sample if hasattr(config, "do_sample") else False,
+            "early_stopping": True,
+            "length_penalty": (
+                config.length_penalty if hasattr(config, "length_penalty") else 1.0
+            ),
+            "max_length": config.max_position_embeddings,
+            "min_length": 0,
+            "no_repeat_ngram_size": (
+                config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
+            ),
+            "num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
+            "num_return_sequences": (
+                config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
+            ),
+            "past_present_share_buffer": False,
+            "repetition_penalty": (
+                config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
+            ),
+            "temperature": config.temperature if hasattr(config, "temperature") else 1.0,
+            "top_k": config.top_k if hasattr(config, "top_k") else 50,
+            "top_p": config.top_p if hasattr(config, "top_p") else 1.0,
+        },
+    }
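Editor's note: ``make_genai_config`` above only builds the dictionary; writing it next to
the exported model is left to the caller. A minimal sketch under that assumption, where
``model.onnx`` is a hypothetical exported file and the configuration comes from
:epkg:`transformers`:

.. code-block:: python

    import json
    from transformers import AutoConfig
    from onnx_diagnostic.helpers.model_builder_helper import make_genai_config

    config = AutoConfig.from_pretrained("arnir0/Tiny-LLM")
    genai_config = make_genai_config(config, "model.onnx")
    # onnxruntime-genai looks for this file in the model's folder.
    with open("genai_config.json", "w") as f:
        json.dump(genai_config, f, indent=4)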
pad_token_id,
+            # "type": self.model_type[ : self.model_type.find("For")
+            # if "For" in self.model_type else len(self.model_type)].lower(),
+            "vocab_size": config.vocab_size,
+        },
+        "search": {
+            "diversity_penalty": (
+                config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
+            ),
+            "do_sample": config.do_sample if hasattr(config, "do_sample") else False,
+            "early_stopping": True,
+            "length_penalty": (
+                config.length_penalty if hasattr(config, "length_penalty") else 1.0
+            ),
+            "max_length": config.max_position_embeddings,
+            "min_length": 0,
+            "no_repeat_ngram_size": (
+                config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
+            ),
+            "num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
+            "num_return_sequences": (
+                config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
+            ),
+            "past_present_share_buffer": False,
+            "repetition_penalty": (
+                config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
+            ),
+            "temperature": config.temperature if hasattr(config, "temperature") else 1.0,
+            "top_k": config.top_k if hasattr(config, "top_k") else 50,
+            "top_p": config.top_p if hasattr(config, "top_p") else 1.0,
+        },
+    }
diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py
index 9ce89eb1..bfe16900 100644
--- a/onnx_diagnostic/helpers/rt_helper.py
+++ b/onnx_diagnostic/helpers/rt_helper.py
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Tuple, Union
+import json
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
@@ -224,3 +226,73 @@ def onnx_generate(
     if return_session:
         return input_ids, session
     return input_ids
+
+
+def onnx_generate_with_genai(
+    model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
+    input_ids: torch.Tensor,
+    max_new_tokens: int = 100,
+    return_session: bool = False,
+    transformers_config: Optional[Any] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
+    """
+    Uses :epkg:`onnxruntime-genai` to implement a simple method ``generate``
+    for an ONNX model. The function does not expect any ``position_ids`` as input.
+
+    :param model_or_path: filename of the ONNX model or an already loaded
+        :epkg:`onnxruntime-genai` model
+    :param input_ids: input tokens
+    :param max_new_tokens: stops after this number of generated tokens
+    :param return_session: also returns the :epkg:`onnxruntime-genai`
+        model (an ``og.Model`` instance) created by the function
+        when only a filename is given,
+        so that it can be reused for the next calls
+    :param transformers_config: transformers configuration used to write
+        ``genai_config.json`` when the file is missing
+    :return: input tokens concatenated with new tokens
+    """
+    import onnxruntime_genai as og
+
+    if not isinstance(model_or_path, og.Model):
+        from .model_builder_helper import make_genai_config
+
+        assert isinstance(
+            model_or_path, str
+        ), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
+        folder = os.path.dirname(model_or_path)
+        assert os.path.exists(folder), f"Folder {folder!r} does not exist."
+        assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
+        config_file = os.path.join(folder, "genai_config.json")
+        if not os.path.exists(config_file):
+            if not transformers_config:
+                raise FileNotFoundError(
+                    f"Folder {folder!r} does not contain 'genai_config.json'."
+ ) + config = make_genai_config(transformers_config, model_or_path) + with open(config_file, "w") as f: + json.dump(config, f, indent=4) + + config = og.Config(os.path.dirname(config_file)) + if input_ids.is_cuda: + config.clear_providers() + config.append_provider("cuda") + session = og.Model(config) + else: + session = model_or_path + + params = og.GeneratorParams(session) + params.set_search_options(max_new_tokens=max_new_tokens, batch_size=input_ids.shape[0]) + generator = og.Generator(session, params) + + # First call: prefill + cats = [input_ids] + generator.append_tokens(input_ids) + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + cats.append(new_token) + + input_ids = torch.cat(cats, dim=-1) + if return_session: + return input_ids, session + return input_ids diff --git a/requirements-dev.txt b/requirements-dev.txt index 5c64f6e3..818e1c0f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,6 +7,7 @@ huggingface_hub matplotlib onnx-array-api>=0.3.1 onnx +onnxruntime-genai onnxscript openpyxl packaging From 5ea455b41ccd80156568b63caf1a4dc88c2b47a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 30 Oct 2025 11:06:13 +0100 Subject: [PATCH 17/20] find names pattern --- CHANGELOGS.rst | 5 +-- .../ut_helpers/test_model_builder_helper.py | 6 ++++ _unittests/ut_helpers/test_rt_helper.py | 2 -- .../helpers/model_builder_helper.py | 36 +++++++++++++++++-- 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 6f6aeddd..8bdc0ebf 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,8 +4,9 @@ Change Logs 0.8.0 +++++ -* :pr:`276`: implements onnx_generate which implements method generate for an onnx model, - changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...) +* :pr:`278`: implements ``onnx_generate_with_genai`` +* :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...) 
+* :pr:`276`: implements ``onnx_generate``, which implements the method ``generate`` for an onnx model
 * :pr:`275`: fixes function ``patched_vmap``

 0.7.16
diff --git a/_unittests/ut_helpers/test_model_builder_helper.py b/_unittests/ut_helpers/test_model_builder_helper.py
index c61cabb3..23c2ef58 100644
--- a/_unittests/ut_helpers/test_model_builder_helper.py
+++ b/_unittests/ut_helpers/test_model_builder_helper.py
@@ -11,6 +11,7 @@
     import_model_builder,
     create_model_builder,
     save_model_builder,
+    find_names_pattern,
 )
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.helpers.rt_helper import make_feeds
@@ -63,6 +64,11 @@ def test_model_builder_id(self):
             raise unittest.SkipTest("batch_size must be 1 when sequence_length > 1")
         self.assertEqualAny(expected, got)

+    def test_find_names_pattern(self):
+        pats = ["past_key_values_key_0", "past_key_values_key_1"]
+        self.assertEqual("past_key_values_key_%d", find_names_pattern(pats))
+        self.assertEqual("past_key_values_key_%d", find_names_pattern(pats[:1]))
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_helpers/test_rt_helper.py b/_unittests/ut_helpers/test_rt_helper.py
index 69527d09..d4c77c33 100644
--- a/_unittests/ut_helpers/test_rt_helper.py
+++ b/_unittests/ut_helpers/test_rt_helper.py
@@ -138,8 +138,6 @@ def test_onnx_generate(self):
             model, input_ids[:1], 2, max_new_tokens=10, session=session
         )
         self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
-        print("******", res)
-        print("******", expected)
         self.assertEqual(expected.dtype, torch.int64)
         self.assertEqual(expected.shape, (1, 13))
         self.assertEqualArray(expected, res)
diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py
index 7f315df6..71f9df9a 100644
--- a/onnx_diagnostic/helpers/model_builder_helper.py
+++ b/onnx_diagnostic/helpers/model_builder_helper.py
@@ -1,10 +1,11 @@
 import copy
 import importlib.util
 import os
+import re
 import requests
 import sys
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlparse
 from onnx import ModelProto, TensorProto, load as load_model

@@ -340,6 +341,26 @@ def _post(onnx_model):
     return onnx_model


+def find_names_pattern(names: List[str]) -> str:
+    """
+    Finds a repeatable pattern in a list of names.
+    It replaces every run of digits with ``%d``.
+
+    .. 
runpython:: + :showcode: + + from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern + pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"]) + print(pattern) + """ + patterns = [re.sub(r"(\d+)", r"%d", t) for t in names] + unique = set(patterns) + assert ( + len(unique) == 1 + ), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}" + return patterns[0] + + def make_genai_config( config, onnx_filename: str, @@ -398,8 +419,17 @@ def make_genai_config( "filename": onnx_filename, "head_size": shape[-1], "hidden_size": config.hidden_size, - "inputs": input_names, - "outputs": output_names, + "inputs": { + "input_ids": input_names[0], + "attention_mask": input_names[1], + "past_key_names": find_names_pattern(input_names[2::2]), + "past_value_names": find_names_pattern(input_names[3::2]), + }, + "outputs": { + "logits": output_names[0], + "present_key_names": find_names_pattern(output_names[1::2]), + "present_value_names": find_names_pattern(output_names[2::2]), + }, "num_attention_heads": config.num_attention_heads, "num_hidden_layers": len(past_key_values) // 2, "num_key_value_heads": shape[1], From 668ec7caca4082ee91e8a65c703743f1bc9c9880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 30 Oct 2025 15:23:31 +0100 Subject: [PATCH 18/20] add genai --- _unittests/ut_helpers/test_rt_helper.py | 2 -- .../helpers/model_builder_helper.py | 4 +-- onnx_diagnostic/helpers/rt_helper.py | 30 ++++++++++++++----- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/_unittests/ut_helpers/test_rt_helper.py b/_unittests/ut_helpers/test_rt_helper.py index 4086b95a..f165fa16 100644 --- a/_unittests/ut_helpers/test_rt_helper.py +++ b/_unittests/ut_helpers/test_rt_helper.py @@ -59,8 +59,6 @@ def test_onnx_generate(self): model, input_ids[:1], 2, max_new_tokens=10, session=session ) self.assertEqualArray(input_ids[:1], expected[:, :n_inputs]) - print("******", res) - print("******", expected) self.assertEqual(expected.dtype, torch.int64) self.assertEqual(expected.shape, (1, 13)) self.assertEqualArray(expected, res) diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py index 71f9df9a..e218542c 100644 --- a/onnx_diagnostic/helpers/model_builder_helper.py +++ b/onnx_diagnostic/helpers/model_builder_helper.py @@ -416,7 +416,7 @@ def make_genai_config( "log_id": "onnxruntime-genai", "provider_options": [], }, - "filename": onnx_filename, + "filename": os.path.split(onnx_filename)[-1], "head_size": shape[-1], "hidden_size": config.hidden_size, "inputs": { @@ -436,7 +436,7 @@ def make_genai_config( }, "eos_token_id": eos_token_id, "pad_token_id": pad_token_id, - # "type": self.model_type[ : self.model_type.find("For") + "type": config.model_type, # if "For" in self.model_type else len(self.model_type)].lower(), "vocab_size": config.vocab_size, }, diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index 9d40e9d7..4166b904 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -285,7 +285,11 @@ def onnx_generate( import os from onnx_diagnostic.helpers import string_type, string_diff - from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate + from onnx_diagnostic.helpers.rt_helper import ( + onnx_generate, + generate_and_validate, + onnx_generate_with_genai, + ) from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs from 
onnx_diagnostic.torch_export_patches import torch_export_patches
     from onnx_diagnostic.export.api import to_onnx

@@ -315,11 +319,11 @@ def onnx_generate(
             exporter="custom",  # custom, dynamo or onnx-dynamo, modelbuilder
         )

-        print("-- onnx_generate")
+        print("-- generate with onnx")
         onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
         print("-- onnx output", onnx_outputs)

-        print("-- generate")
+        print("-- generate with pytorch")
         torch_outputs, diffs = generate_and_validate(
             model, input_ids[:1], 2, max_new_tokens=10, session=model_name
         )
@@ -327,6 +331,16 @@ def onnx_generate(
         print("-- differences at each step:")
         for i, d in enumerate(diffs):
             print(f"iteration {i}: {string_diff(d)}")
+
+        print("-- generate with genai")
+        genai_outputs, session = onnx_generate_with_genai(
+            model_name,
+            input_ids[:1],
+            max_new_tokens=10,
+            return_session=True,
+            transformers_config=data["configuration"],
+        )
+        print("-- genai output", genai_outputs)
     """
     if not isinstance(model_or_path, InferenceSessionForTorch):
         providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
@@ -439,18 +453,20 @@ def onnx_generate_with_genai(
         session = model_or_path

     params = og.GeneratorParams(session)
-    params.set_search_options(max_new_tokens=max_new_tokens, batch_size=input_ids.shape[0])
+    params.set_search_options(
+        max_length=max_new_tokens + input_ids.shape[1], batch_size=input_ids.shape[0]
+    )
     generator = og.Generator(session, params)

     # First call: prefill
-    cats = [input_ids]
+    cats = []
     generator.append_tokens(input_ids)
     while not generator.is_done():
         generator.generate_next_token()
         new_token = generator.get_next_tokens()[0]
-        cats.append(new_token)
+        cats.append(int(new_token))

-    input_ids = torch.cat(cats, dim=-1)
+    input_ids = torch.cat([input_ids, torch.tensor([cats]).to(input_ids)], dim=-1)
     if return_session:
         return input_ids, session
     return input_ids
From e20f075b50e07a9900ddc115d927c95b4c621db6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 30 Oct 2025 15:24:53 +0100
Subject: [PATCH 19/20] doc
---
 onnx_diagnostic/helpers/rt_helper.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py
index 4166b904..d03f3c24 100644
--- a/onnx_diagnostic/helpers/rt_helper.py
+++ b/onnx_diagnostic/helpers/rt_helper.py
@@ -323,6 +323,7 @@ def onnx_generate(
         onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
         print("-- onnx output", onnx_outputs)

+        # The example continues: other functions run the same generation.
         print("-- generate with pytorch")
         torch_outputs, diffs = generate_and_validate(
             model, input_ids[:1], 2, max_new_tokens=10, session=model_name
@@ -422,6 +423,9 @@ def onnx_generate_with_genai(
     :param transformers_config: transformers configuration used to write
         ``genai_config.json`` when the file is missing
     :return: input tokens concatenated with new tokens
+
+    See the example given with function :func:`onnx_generate
+    <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
""" import onnxruntime_genai as og From a01b39f9da73686b1ec3a530e45a1cb17cfc6c38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 30 Oct 2025 15:47:45 +0100 Subject: [PATCH 20/20] fix doc --- _doc/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/_doc/conf.py b/_doc/conf.py index 563c87a5..c2bf29f2 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -239,6 +239,7 @@ def linkcode_resolve(domain, info): "ONNX": "https://onnx.ai/", "ONNX Operators": "https://onnx.ai/onnx/operators/", "onnxruntime": "https://onnxruntime.ai/", + "onnxruntime-genai": "https://github.com/microsoft/onnxruntime-genai", "onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html", "onnxruntime kernels": "https://onnxruntime.ai/docs/reference/operators/OperatorKernels.html", "onnx-array-api": "https://sdpython.github.io/doc/onnx-array-api/dev/",