From 63a1408fcd31139ab1073e3bd93a25848b00c1de Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 16 Jun 2025 16:05:12 +0200 Subject: [PATCH 1/7] Support for static cache --- _unittests/ut_helpers/test_cache_helper.py | 44 +++++- _unittests/ut_tasks/try_tasks.py | 29 ++++ _unittests/ut_torch_models/test_tiny_llms.py | 45 +++++- onnx_diagnostic/_command_lines_parser.py | 148 +++++++++--------- onnx_diagnostic/helpers/cache_helper.py | 59 +++++++ onnx_diagnostic/helpers/helper.py | 33 +++- onnx_diagnostic/helpers/torch_helper.py | 12 ++ .../tasks/automatic_speech_recognition.py | 3 + onnx_diagnostic/tasks/feature_extraction.py | 3 + onnx_diagnostic/tasks/fill_mask.py | 3 + onnx_diagnostic/tasks/image_classification.py | 3 + onnx_diagnostic/tasks/image_text_to_text.py | 3 + onnx_diagnostic/tasks/mixture_of_expert.py | 3 + onnx_diagnostic/tasks/object_detection.py | 3 + onnx_diagnostic/tasks/sentence_similarity.py | 3 + onnx_diagnostic/tasks/summarization.py | 3 + onnx_diagnostic/tasks/text2text_generation.py | 3 + onnx_diagnostic/tasks/text_classification.py | 3 + onnx_diagnostic/tasks/text_generation.py | 133 +++++++++++----- .../tasks/zero_shot_image_classification.py | 3 + .../onnx_export_serialization.py | 37 +++++ .../torch_models/untrained/llm_tiny_llm.py | 72 +++------ 22 files changed, 466 insertions(+), 182 deletions(-) diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index 0dfe45e3..56ead0ae 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -2,13 +2,14 @@ import torch import transformers from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers -from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.helpers import string_type, max_diff from onnx_diagnostic.helpers.cache_helper import ( flatten_unflatten_for_dynamic_shapes, make_dynamic_cache, make_encoder_decoder_cache, make_mamba_cache, make_sliding_window_cache, + make_static_cache, ) from onnx_diagnostic.export import CoupleInputsDynamicShapes from onnx_diagnostic.torch_export_patches.patch_inputs import ( @@ -104,6 +105,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self): ] ), ) + self.assertEqual(0, max_diff(c2, c2)["abs"]) self.assertIsInstance(c2, transformers.cache_utils.EncoderDecoderCache) flat, _spec = torch.utils._pytree.tree_flatten(c2) self.assertIsInstance(flat, list) @@ -149,6 +151,7 @@ def test_make_mamba_cache(self): "ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])", text, ) + self.assertEqual(0, max_diff(cache, cache)["abs"]) def test_make_sliding_window_cache(self): cache = make_sliding_window_cache( @@ -164,6 +167,45 @@ def test_make_sliding_window_cache(self): "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])", text, ) + self.assertEqual(0, max_diff(cache, cache)["abs"]) + + def test_make_static_cache(self): + cache = make_static_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + text = self.string_type(cache, with_shape=True) + self.assertEqual( + "StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], " + "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])", + text, + ) + self.assertEqual(0, max_diff(cache, cache)["abs"]) + + def test_unflatten_flatten_static_cache(self): + with torch_export_patches(patch_transformers=True): + c2 = make_static_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 
5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + self.assertEqual(0, max_diff(c2, c2)["abs"]) + self.assertIsInstance(c2, transformers.cache_utils.StaticCache) + flat, _spec = torch.utils._pytree.tree_flatten(c2) + self.assertIsInstance(flat, list) + self.assertEqual(len(flat), 6) + unflat = flatten_unflatten_for_dynamic_shapes(c2) + self.assertIsInstance(unflat, list) + self.assertEqual(len(unflat), 2) + self.assertEqual( + "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]", + self.string_type(unflat, with_shape=True), + ) if __name__ == "__main__": diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py index 9217da63..6d6df11f 100644 --- a/_unittests/ut_tasks/try_tasks.py +++ b/_unittests/ut_tasks/try_tasks.py @@ -98,6 +98,35 @@ def test_text2text_generation(self): ) print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) + @never_test() + def test_text2text_generation_static(self): + # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k text2t + + import torch + from transformers import AutoTokenizer, AutoModelForCausalLM + + tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM") + model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM") + + text = "def greet(user): print(f'hello !')" + input_ids = tokenizer(text, return_tensors="pt").input_ids + mask = ( + torch.tensor([1 for i in range(input_ids.shape[1])]) + .to(torch.int64) + .reshape((1, -1)) + ) + + # simply generate a single sequence + print() + with steal_forward(model): + generated_ids = model.generate( + input_ids=input_ids, + attention_mask=mask, + max_new_tokens=117, + cache_implementation="static", + ) + print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) + @never_test() def test_text_generation_phi4_mini(self): # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index ae4f5682..3e8c7de9 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -1,32 +1,61 @@ import copy import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings from onnx_diagnostic.torch_models.llms import get_tiny_llm from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str class TestTinyLlm(ExtTestCase): - def test_get_tiny_llm(self): + def test_tiny_llm_run_dynamic(self): data = get_tiny_llm() model, inputs = data["model"], data["inputs"] self.assertIn("DynamicCache", string_type(inputs)) model(**inputs) @ignore_warnings(UserWarning) - @requires_transformers("4.53") - def test_export_tiny_llm_1(self): + def test_tiny_llm_export_dynamic(self): data = get_tiny_llm() model, inputs = data["model"], data["inputs"] expected = model(**copy.deepcopy(inputs)) self.assertEqual( {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) ) - ep = torch.export.export( - model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=data["dynamic_shapes"] + with torch_export_patches(patch_transformers=True): + ep = torch.export.export( + model, + (), + kwargs=copy.deepcopy(inputs), + dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), + ) + 
got = ep.module()(**inputs) + self.assertEqualArrayAny(expected, got) + + def test_tiny_llm_run_static(self): + data = get_tiny_llm(use_static_cache=True) + model, inputs = data["model"], data["inputs"] + self.assertIn("StaticCache", string_type(inputs)) + model(**inputs) + + @ignore_warnings(UserWarning) + def test_tiny_llm_export_static(self): + data = get_tiny_llm(use_static_cache=True) + model, inputs = data["model"], data["inputs"] + expected = model(**copy.deepcopy(inputs)) + self.assertEqual( + {"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs) ) - got = ep.module()(**inputs) - self.assertEqualArrayAny(expected, got) + with torch_export_patches(patch_transformers=True, stop_if_static=1): + ep = torch.export.export( + model, + (), + kwargs=copy.deepcopy(inputs), + dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), + ) + got = ep.module()(**inputs) + self.assertEqualArrayAny(expected, got) if __name__ == "__main__": diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index bbcdf47e..282de6b9 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -7,17 +7,16 @@ import onnx from typing import Any, Dict, List, Optional, Union from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction -from textwrap import dedent def get_parser_lighten() -> ArgumentParser: parser = ArgumentParser( prog="lighten", - description=dedent( + description=textwrap.dedent( + """ + Removes the weights from a heavy model, stores statistics to restore + random weights. """ - Removes the weights from a heavy model, stores statistics to restore - random weights. - """ ), epilog="This is mostly used to write unit tests without adding " "a big onnx file to the repository.", @@ -70,11 +69,11 @@ def _cmd_lighten(argv: List[Any]): def get_parser_unlighten() -> ArgumentParser: parser = ArgumentParser( prog="unlighten", - description=dedent( + description=textwrap.dedent( + """ + Restores random weights for a model reduces with command lighten, + the command expects to find a file nearby with extension '.stats'. """ - Restores random weights for a model reduces with command lighten, - the command expects to find a file nearby with extension '.stats'. - """ ), epilog="This is mostly used to write unit tests without adding " "a big onnx file to the repository.", @@ -120,11 +119,7 @@ def _cmd_unlighten(argv: List[Any]): def get_parser_print() -> ArgumentParser: parser = ArgumentParser( prog="print", - description=dedent( - """ - Prints the model on the standard output. - """ - ), + description="Prints the model on the standard output.", epilog="To show a model.", formatter_class=RawTextHelpFormatter, ) @@ -143,6 +138,7 @@ def get_parser_print() -> ArgumentParser: "\n" ) ), + formatter_class=RawTextHelpFormatter, ) parser.add_argument("input", type=str, help="onnx model to load") return parser @@ -171,11 +167,11 @@ def _cmd_print(argv: List[Any]): def get_parser_find() -> ArgumentParser: parser = ArgumentParser( prog="find", - description=dedent( + description=textwrap.dedent( + """ + Look into a model and search for a set of names, + tells which node is consuming or producing it. """ - Look into a model and search for a set of names, - tells which node is consuming or producing it. 
- """ ), epilog="Enables Some quick validation.", ) @@ -191,8 +187,8 @@ def get_parser_find() -> ArgumentParser: "--names", type=str, required=False, - help="names to look at comma separated values, if 'SHADOW', " - "search for shadowing names", + help="Names to look at comma separated values, if 'SHADOW', " + "search for shadowing names.", ) parser.add_argument( "-v", @@ -206,7 +202,7 @@ def get_parser_find() -> ArgumentParser: "--v2", default=False, action=BooleanOptionalAction, - help="use enumerate_results instead of onnx_find", + help="Uses enumerate_results instead of onnx_find.", ) return parser @@ -235,12 +231,13 @@ def _cmd_find(argv: List[Any]): def get_parser_config() -> ArgumentParser: parser = ArgumentParser( prog="config", - description=dedent( + description=textwrap.dedent( + """ + Prints out a configuration for a model id, + prints the associated task as well. """ - Prints out a configuration for a model id, - prints the associated task as well. - """ ), + formatter_class=RawTextHelpFormatter, epilog="", ) parser.add_argument( @@ -248,29 +245,29 @@ def get_parser_config() -> ArgumentParser: "--mid", type=str, required=True, - help="model id, usually /", + help="model id, usually `/`", ) parser.add_argument( "-t", "--task", default=False, action=BooleanOptionalAction, - help="displays the task as well", + help="Displays the task as well.", ) parser.add_argument( "-c", "--cached", default=True, action=BooleanOptionalAction, - help="uses cached configuration, only available for some of them, " - "mostly for unit test purposes", + help="Uses cached configuration, only available for some of them,\n" + "mostly for unit test purposes.", ) parser.add_argument( "--mop", metavar="KEY=VALUE", nargs="*", help="Additional model options, use to change some parameters of the model, " - "example: --mop attn_implementation=eager", + "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager", action=_ParseDict, ) return parser @@ -329,15 +326,16 @@ def __call__(self, parser, namespace, values, option_string=None): def get_parser_validate() -> ArgumentParser: parser = ArgumentParser( - prog="test", - description=dedent( + prog="validate", + description=textwrap.dedent( + """ + Prints out dummy inputs for a particular task or a model id. + If both mid and task are empty, the command line displays the list + of supported tasks. """ - Prints out dummy inputs for a particular task or a model id. - If both mid and task are empty, the command line displays the list - of supported tasks. 
- """ ), epilog="If the model id is specified, one untrained version of it is instantiated.", + formatter_class=RawTextHelpFormatter, ) parser.add_argument("-m", "--mid", type=str, help="model id, usually /") parser.add_argument("-t", "--task", default=None, help="force the task to use") @@ -348,62 +346,61 @@ def get_parser_validate() -> ArgumentParser: "--run", default=False, action=BooleanOptionalAction, - help="runs the model to check it runs", + help="Runs the model to check it runs.", ) parser.add_argument( "-q", "--quiet", default=False, action=BooleanOptionalAction, - help="catches exception, report them in the summary", + help="Catches exception, reports them in the summary.", ) parser.add_argument( "--patch", default=True, action=BooleanOptionalAction, - help="applies patches before exporting", + help="Applies patches before exporting.", ) parser.add_argument( "--rewrite", default=True, action=BooleanOptionalAction, - help="applies rewrite before exporting", + help="Applies rewrite before exporting.", ) parser.add_argument( "--stop-if-static", default=0, type=int, - help="raises an exception if a dynamic dimension becomes static", + help="Raises an exception if a dynamic dimension becomes static.", ) parser.add_argument( "--trained", default=False, action=BooleanOptionalAction, - help="validate the trained model (requires downloading)", + help="Validates the trained model (requires downloading).", ) parser.add_argument( "--inputs2", default=True, action=BooleanOptionalAction, - help="if run is on, the command lines validates the model on a " - "second set of inputs to check the exported model supports dynamism", + help="Validates the model on a second set of inputs\n" + "to check the exported model supports dynamism.", ) parser.add_argument( "--runtime", choices=["onnxruntime", "torch", "ref"], default="onnxruntime", - help="onnx runtime to use, onnxruntime by default", + help="onnx runtime to use, `onnxruntime` by default", ) parser.add_argument( "-o", "--dump-folder", - help="if not empty, a folder is created to dumps statistics, " - "exported program, onnx...", + help="A folder is created to dumps statistics,\nexported program, onnx...", ) parser.add_argument( "--drop", - help="drops the following inputs names, it should be a list " - "with comma separated values", + help="Drops the following inputs names, it should be a list\n" + "with comma separated values.", ) parser.add_argument( "--opset", @@ -413,24 +410,25 @@ def get_parser_validate() -> ArgumentParser: ) parser.add_argument( "--subfolder", - help="subfolder where to find the model and the configuration", + help="Subfolder where to find the model and the configuration.", ) parser.add_argument( "--ortfusiontype", required=False, - help="applies onnxruntime fusion, this parameter should contain the " - "model type or multiple values separated by `|`. `ALL` can be used " - "to run them all", + help="Applies onnxruntime fusion, this parameter should contain the\n" + "model type or multiple values separated by `|`. 
`ALL` can be used\n" + "to run them all.", ) parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity") - parser.add_argument("--dtype", help="changes dtype if necessary") - parser.add_argument("--device", help="changes the device if necessary") + parser.add_argument("--dtype", help="Changes dtype if necessary.") + parser.add_argument("--device", help="Changes the device if necessary.") parser.add_argument( "--iop", metavar="KEY=VALUE", nargs="*", - help="Additional input options, use to change the default " - "inputs use to export, example: --iop cls_cache=SlidingWindowCache", + help="Additional input options, use to change the default" + "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache" + "\n --iop cls_cache=StaticCache", action=_ParseDict, ) parser.add_argument( @@ -438,8 +436,8 @@ def get_parser_validate() -> ArgumentParser: metavar="KEY=VALUE", nargs="*", help="Additional model options, use to change some parameters of the model, " - "example: ``--mop attn_implementation=eager`` or " - "``--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"``", + "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n " + "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"", action=_ParseDict, ) parser.add_argument( @@ -519,11 +517,7 @@ def _cmd_validate(argv: List[Any]): def get_parser_stats() -> ArgumentParser: parser = ArgumentParser( prog="stats", - description=dedent( - """ - Prints out statistics on an ONNX model. - """ - ), + description="Prints out statistics on an ONNX model.", epilog="", ) parser.add_argument( @@ -570,8 +564,8 @@ def get_parser_stats() -> ArgumentParser: required=False, default="", type=str, - help="keeps only tensors whose name verifies " - "this regular expression, empty = no filter", + help="Keeps only tensors whose name verifies " + "this regular expression, empty = no filter.", ) return parser @@ -623,17 +617,17 @@ def get_main_parser() -> ArgumentParser: formatter_class=RawTextHelpFormatter, epilog=textwrap.dedent( """ - Type 'python -m onnx_diagnostic --help' - to get help for a specific command. - - config - prints a configuration for a model id - find - find node consuming or producing a result - lighten - makes an onnx model lighter by removing the weights, - unlighten - restores an onnx model produces by the previous experiment - print - prints the model on standard output - validate - validate a model - stats - produces statistics on a model - """ + Type 'python -m onnx_diagnostic --help' + to get help for a specific command. + + config - prints a configuration for a model id + find - find node consuming or producing a result + lighten - makes an onnx model lighter by removing the weights, + unlighten - restores an onnx model produces by the previous experiment + print - prints the model on standard output + validate - validate a model + stats - produces statistics on a model + """ ), ) parser.add_argument( diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 3f50ecf7..e37a1e2d 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -141,6 +141,65 @@ def make_dynamic_cache( return cache +def make_static_cache( + key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], +) -> transformers.cache_utils.DynamicCache: + """ + Creates an instance of :class:`transformers.cache_utils.StaticCache`. 
+ :param key_value_pairs: list of pairs of (key, values) + :return: :class:`transformers.cache_utils.StaticCache` + + Example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.helpers import string_type + from onnx_diagnostic.helpers.cache_helper import make_static_cache + + n_layers = 2 + bsize, nheads, slen, dim = 2, 4, 3, 7 + + past_key_values = make_static_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ) + for i in range(n_layers) + ] + ) + print(string_type(past_key_values, with_shape=True)) + """ + + class _config: + def __init__(self): + self.head_dim = key_value_pairs[0][0].shape[-1] + self.num_attention_heads = key_value_pairs[0][0].shape[1] + self.num_hidden_layers = len(key_value_pairs) + + cache = transformers.cache_utils.StaticCache( + _config(), + max_batch_size=key_value_pairs[0][0].shape[0], + device=key_value_pairs[0][0].device, + dtype=key_value_pairs[0][0].dtype, + max_cache_len=key_value_pairs[0][0].shape[2], + ) + for i in range(len(key_value_pairs)): + assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, ( + f"Shape mismatch, expected {cache.key_cache[i].shape}, " + f"got {key_value_pairs[i][0].shape}" + ) + cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0] + assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, ( + f"Shape mismatch, expected {cache.value_cache[i].shape}, " + f"got {key_value_pairs[i][1].shape}" + ) + cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1] + return cache + + def make_encoder_decoder_cache( self_attention_cache: transformers.cache_utils.DynamicCache, cross_attention_cache: transformers.cache_utils.DynamicCache, diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 7151ef3d..7242806a 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -558,7 +558,7 @@ def string_type( print(f"[string_type] CACHE1:{type(obj)}") return f"MambaCache(conv_states={c}, ssm_states={d})" - if obj.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"): + if obj.__class__.__name__ in {"DynamicCache", "SlidingWindowCache", "StaticCache"}: kc = string_type( obj.key_cache, with_shape=with_shape, @@ -857,7 +857,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any: return flatten_object(list(x.values()), drop_keys=drop_keys) return flatten_object(list(x.items()), drop_keys=drop_keys) - if x.__class__.__name__ == "DynamicCache": + if x.__class__.__name__ in {"DynamicCache", "StaticCache"}: res = flatten_object(x.key_cache) + flatten_object(x.value_cache) return tuple(res) if x.__class__.__name__ == "EncoderDecoderCache": @@ -1424,10 +1424,37 @@ def max_diff( f"level={level}" ) + if expected.__class__.__name__ == "StaticCache": + if got.__class__.__name__ == "StaticCache": + if verbose >= 6: + print(f"[max_diff] StaticCache: {string_type(expected)} ? 
{string_type(got)}") + return max_diff( + [expected.key_cache, expected.value_cache], + [got.key_cache, got.value_cache], + verbose=verbose, + hist=hist, + ) + if isinstance(got, tuple) and len(got) == 2: + return max_diff( + [expected.key_cache, expected.value_cache], + [got[0], got[1]], + debug_info=_debug(expected.__class__.__name__), + **_dkws, + ) + raise AssertionError( + f"StaticCache not fully implemented with classes " + f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, " + f"and expected={string_type(expected)}, got={string_type(got)},\n" + f"level={level}" + ) + if expected.__class__.__name__ == "SlidingWindowCache": if got.__class__.__name__ == "SlidingWindowCache": if verbose >= 6: - print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}") + print( + f"[max_diff] SlidingWindowCache: " + f"{string_type(expected)} ? {string_type(got)}" + ) return max_diff( [expected.key_cache, expected.value_cache], [got.key_cache, got.value_cache], diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 83ca95a4..d4437fc4 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -16,6 +16,7 @@ make_encoder_decoder_cache, make_sliding_window_cache, make_mamba_cache, + make_static_cache, ) from .mini_onnx_builder import create_onnx_model_from_input_tensors from .onnx_helper import ( @@ -727,6 +728,15 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any: ) ) ) + if value.__class__.__name__ == "StaticCache": + return make_static_cache( + list( + zip( + [t.to(to_value) for t in value.key_cache], + [t.to(to_value) for t in value.value_cache], + ) + ) + ) if value.__class__.__name__ == "EncoderDecoderCache": return make_encoder_decoder_cache( to_any(value.self_attention_cache, to_value), @@ -773,6 +783,8 @@ def torch_deepcopy(value: Any) -> Any: return make_dynamic_cache( torch_deepcopy(list(zip(value.key_cache, value.value_cache))) ) + if value.__class__.__name__ == "StaticCache": + return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache)))) if value.__class__.__name__ == "SlidingWindowCache": return make_sliding_window_cache( torch_deepcopy(list(zip(value.key_cache, value.value_cache))) diff --git a/onnx_diagnostic/tasks/automatic_speech_recognition.py b/onnx_diagnostic/tasks/automatic_speech_recognition.py index 346c3fa2..f1b4ae6b 100644 --- a/onnx_diagnostic/tasks/automatic_speech_recognition.py +++ b/onnx_diagnostic/tasks/automatic_speech_recognition.py @@ -69,6 +69,9 @@ def get_inputs( use_cache:bool,return_dict:bool ) """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "seq_length" diff --git a/onnx_diagnostic/tasks/feature_extraction.py b/onnx_diagnostic/tasks/feature_extraction.py index 9ef52058..4bac2aed 100644 --- a/onnx_diagnostic/tasks/feature_extraction.py +++ b/onnx_diagnostic/tasks/feature_extraction.py @@ -35,6 +35,9 @@ def get_inputs( token_type_ids:T7s1x13[0,0:A0.0], attention_mask:T7s1x13[1,1:A1.0]) """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." 
batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "sequence_length" shapes = { diff --git a/onnx_diagnostic/tasks/fill_mask.py b/onnx_diagnostic/tasks/fill_mask.py index 14020e1e..63a05811 100644 --- a/onnx_diagnostic/tasks/fill_mask.py +++ b/onnx_diagnostic/tasks/fill_mask.py @@ -35,6 +35,9 @@ def get_inputs( token_type_ids:T7s1x13[0,0:A0.0], attention_mask:T7s1x13[1,1:A1.0]) """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "sequence_length" shapes = { diff --git a/onnx_diagnostic/tasks/image_classification.py b/onnx_diagnostic/tasks/image_classification.py index 88d9c134..cc14e4a3 100644 --- a/onnx_diagnostic/tasks/image_classification.py +++ b/onnx_diagnostic/tasks/image_classification.py @@ -48,6 +48,9 @@ def get_inputs( :param input_height: input height :return: dictionary """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." assert isinstance( input_width, int ), f"Unexpected type for input_width {type(input_width)}{config}" diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 50014621..1ae22537 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -52,6 +52,9 @@ def get_inputs( :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) :return: dictionary """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) diff --git a/onnx_diagnostic/tasks/mixture_of_expert.py b/onnx_diagnostic/tasks/mixture_of_expert.py index b7e5af37..be6b7828 100644 --- a/onnx_diagnostic/tasks/mixture_of_expert.py +++ b/onnx_diagnostic/tasks/mixture_of_expert.py @@ -61,6 +61,9 @@ def get_inputs( :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) :return: dictionary """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." assert not add_second_input, "add_second_input=True not yet implemented" raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.") diff --git a/onnx_diagnostic/tasks/object_detection.py b/onnx_diagnostic/tasks/object_detection.py index 2b2ec61d..d8ce8073 100644 --- a/onnx_diagnostic/tasks/object_detection.py +++ b/onnx_diagnostic/tasks/object_detection.py @@ -41,6 +41,9 @@ def get_inputs( :param input_height: input height :return: dictionary """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." assert isinstance( input_width, int ), f"Unexpected type for input_width {type(input_width)}{config}" diff --git a/onnx_diagnostic/tasks/sentence_similarity.py b/onnx_diagnostic/tasks/sentence_similarity.py index 808ae039..4e304c47 100644 --- a/onnx_diagnostic/tasks/sentence_similarity.py +++ b/onnx_diagnostic/tasks/sentence_similarity.py @@ -35,6 +35,9 @@ def get_inputs( token_type_ids:T7s1x13[0,0:A0.0], attention_mask:T7s1x13[1,1:A1.0]) """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." 
batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "seq_length" shapes = { diff --git a/onnx_diagnostic/tasks/summarization.py b/onnx_diagnostic/tasks/summarization.py index ca70e751..3b2231a1 100644 --- a/onnx_diagnostic/tasks/summarization.py +++ b/onnx_diagnostic/tasks/summarization.py @@ -62,6 +62,9 @@ def get_inputs( decoder_input_ids:T7s1x1, encoder_outputs:dict(last_hidden_state:T1s1x16x512) """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096) diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py index 85ddcbd7..6dd0e3b6 100644 --- a/onnx_diagnostic/tasks/text2text_generation.py +++ b/onnx_diagnostic/tasks/text2text_generation.py @@ -64,6 +64,9 @@ def get_inputs( decoder_input_ids:T7s1x1, encoder_outputs:dict(last_hidden_state:T1s1x16x512) """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096) diff --git a/onnx_diagnostic/tasks/text_classification.py b/onnx_diagnostic/tasks/text_classification.py index aaaa8838..e3a1d727 100644 --- a/onnx_diagnostic/tasks/text_classification.py +++ b/onnx_diagnostic/tasks/text_classification.py @@ -35,6 +35,9 @@ def get_inputs( token_type_ids:T7s1x13[0,0:A0.0], attention_mask:T7s1x13[1,1:A1.0]) """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." 
batch = torch.export.Dim("batch", min=1, max=1024) seq_length = "seq_length" # torch.export.Dim("sequence_length", min=1, max=1024) shapes = { diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index d4ff57ff..6fbc7297 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -5,6 +5,7 @@ make_dynamic_cache, make_mamba_cache, make_sliding_window_cache, + make_static_cache, ) from ..helpers.config_helper import update_config, check_hasattr, _pick @@ -151,52 +152,98 @@ def get_inputs( assert config, "head_dim is None, the value cannot be set without a configuration" head_dim = config.hidden_size // config.num_attention_heads - shapes = { - "input_ids": {0: batch, 1: seq_length}, - "attention_mask": { - 0: batch, - 1: "cache+seq", # cache_length + seq_length - }, - "position_ids": { - 0: batch, - 1: "cache+seq", # cache_length + seq_length - }, - "past_key_values": [ - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - ], + cache_name = ( + cls_cache + if cls_cache is None or isinstance(cls_cache, str) + else cls_cache.__name__ + ) + make_caches = { + "DynamicCache": make_dynamic_cache, + "SlidingWindowCache": make_sliding_window_cache, + "StaticCache": make_static_cache, } - - make_cache = ( - make_sliding_window_cache - if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache) - else make_dynamic_cache + assert cache_name is None or cache_name in make_caches, ( + f"Unable to handle cls_cache={cache_name!r}, it should be in " + f"{sorted(make_caches)}" ) + make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name] + is_static = cache_name == "StaticCache" - inputs = dict( - input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to( - torch.int64 - ), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( - torch.int64 - ), - position_ids=torch.arange(sequence_length, sequence_length + sequence_length2) - .to(torch.int64) - .expand((batch_size, -1)), - past_key_values=make_cache( - [ - ( - torch.randn( - batch_size, num_key_value_heads, sequence_length, head_dim - ), - torch.randn( - batch_size, num_key_value_heads, sequence_length, head_dim - ), - ) - for i in range(num_hidden_layers) - ] - ), - ) + if is_static: + # static + shapes = { + "input_ids": {0: batch, 1: seq_length}, + "attention_mask": {0: batch, 2: "seq"}, + "cache_position": {1: "seq"}, + "past_key_values": [ + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + ], + } + inputs = dict( + input_ids=torch.randint( + 0, dummy_max_token_id, (batch_size, sequence_length2) + ).to(torch.int64), + attention_mask=torch.ones( + (batch_size, num_key_value_heads, sequence_length2, head_dim) + ).to(torch.bool), + cache_position=torch.arange(sequence_length2).to(torch.int64), + past_key_values=make_cache( + [ + ( + torch.randn( + batch_size, num_key_value_heads, sequence_length, head_dim + ), + torch.randn( + batch_size, num_key_value_heads, sequence_length, head_dim + ), + ) + for i in range(num_hidden_layers) + ] + ), + ) + else: + # dynamic + shapes = { + "input_ids": {0: batch, 1: seq_length}, + "attention_mask": { + 0: batch, + 1: "cache+seq", # cache_length + seq_length + }, + "position_ids": { + 0: batch, + 1: "cache+seq", # cache_length + seq_length + }, + "past_key_values": [ + 
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + ], + } + + inputs = dict( + input_ids=torch.randint( + 0, dummy_max_token_id, (batch_size, sequence_length2) + ).to(torch.int64), + attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( + torch.int64 + ), + position_ids=torch.arange(sequence_length, sequence_length + sequence_length2) + .to(torch.int64) + .expand((batch_size, -1)), + past_key_values=make_cache( + [ + ( + torch.randn( + batch_size, num_key_value_heads, sequence_length, head_dim + ), + torch.randn( + batch_size, num_key_value_heads, sequence_length, head_dim + ), + ) + for i in range(num_hidden_layers) + ] + ), + ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: res["inputs2"] = get_inputs( diff --git a/onnx_diagnostic/tasks/zero_shot_image_classification.py b/onnx_diagnostic/tasks/zero_shot_image_classification.py index a341a191..83163552 100644 --- a/onnx_diagnostic/tasks/zero_shot_image_classification.py +++ b/onnx_diagnostic/tasks/zero_shot_image_classification.py @@ -55,6 +55,9 @@ def get_inputs( # attention_mask:T7s2x7 # pixel_values:T1s2x3x224x224 """ + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." assert isinstance( input_width, int ), f"Unexpected type for input_width {type(input_width)}{config}" diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 10589ff9..767f2739 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -9,9 +9,11 @@ MambaCache, EncoderDecoderCache, SlidingWindowCache, + StaticCache, ) from transformers.modeling_outputs import BaseModelOutput from ..helpers import string_type +from ..helpers.cache_helper import make_static_cache PATCH_OF_PATCHES: Set[Any] = set() @@ -175,6 +177,13 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]] flatten_with_keys_sliding_window_cache, verbose=verbose, ), + StaticCache=register_class_serialization( + StaticCache, + flatten_static_cache, + unflatten_static_cache, + flatten_with_keys_static_cache, + verbose=verbose, + ), ) @@ -309,6 +318,34 @@ def unflatten_dynamic_cache( return cache +############## +# DynamicCache +############## + + +def flatten_static_cache( + cache: StaticCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" + flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_static_cache( + cache: StaticCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.""" + values, context = flatten_static_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_static_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> StaticCache: + """Restores a :class:`transformers.cache_utils.StaticCache` from python objects.""" + return make_static_cache(list(zip(values[0], values[1]))) + + #################### # SlidingWindowCache #################### diff --git 
a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py index 0f9d660e..4aef5b55 100644 --- a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +++ b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py @@ -1,7 +1,5 @@ from typing import Any, Dict -import torch import transformers -from ...helpers.cache_helper import make_dynamic_cache def get_tiny_llm( @@ -9,6 +7,7 @@ def get_tiny_llm( sequence_length: int = 30, sequence_length2: int = 3, dynamic_rope: bool = False, + use_static_cache: bool = False, **kwargs, ) -> Dict[str, Any]: """ @@ -18,11 +17,14 @@ def get_tiny_llm( :param sequence_length: sequence length :param sequence_length2: new sequence length :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) + :param use_static_cache: use StaticCache instead of DynamicCache :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1`` :return: dictionary See :ref:`l-plot-tiny-llm-export` or :ref:`l-plot-tiny-llm-export-patched` for examples. """ + from ...tasks.text_generation import get_inputs + config = { "architectures": ["LlamaForCausalLM"], "bos_token_id": 1, @@ -48,56 +50,26 @@ def get_tiny_llm( config.update(**kwargs) conf = transformers.LlamaConfig(**config) + conf.cache_implementation = "static" model = transformers.LlamaForCausalLM(conf) model.eval() - # now the inputs - cache_last_dim = 96 - max_token_id = config["vocab_size"] - 1 - n_layers = config["num_hidden_layers"] - num_key_value_heads = config["num_key_value_heads"] - - batch = torch.export.Dim("batch", min=1, max=1024) - seq_length = torch.export.Dim("seq_length", min=1, max=8192) - cache_length = torch.export.Dim("cache_length", min=1, max=8192) + res = get_inputs( + model, + conf, + dummy_max_token_id=config["vocab_size"], + num_hidden_layers=config["num_hidden_layers"], + batch_size=batch_size, + sequence_length=sequence_length, + sequence_length2=sequence_length2, + dynamic_rope=dynamic_rope, + num_key_value_heads=config["num_key_value_heads"], + cls_cache="StaticCache" if use_static_cache else "DynamicCache", + ) - shapes = { - "input_ids": {0: batch, 1: seq_length}, - "attention_mask": { - 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length - }, - "position_ids": { - 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length - }, - "past_key_values": [ - [{0: batch, 2: cache_length} for _ in range(n_layers)], - [{0: batch, 2: cache_length} for _ in range(n_layers)], - ], - } - inputs = dict( - input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to( - torch.int64 - ), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( - torch.int64 - ), - position_ids=torch.arange(sequence_length, sequence_length + sequence_length2) - .to(torch.int64) - .expand((batch_size, -1)), - past_key_values=make_dynamic_cache( - [ - ( - torch.randn( - batch_size, num_key_value_heads, sequence_length, cache_last_dim - ), - torch.randn( - batch_size, num_key_value_heads, sequence_length, cache_last_dim - ), - ) - for i in range(n_layers) - ] - ), + return dict( + inputs=res["inputs"], + model=model, + dynamic_shapes=res["dynamic_shapes"], + configuration=conf, ) - return dict(inputs=inputs, model=model, dynamic_shapes=shapes, configuration=conf) From e1a8e1cad590a28d1fd9b8ac2e642614962a2b7a Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 16 Jun 2025 16:19:15 +0200 Subject: [PATCH 2/7] fix cache --- onnx_diagnostic/ext_test_case.py | 2 +- 
onnx_diagnostic/tasks/text_generation.py | 2 +- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index e6340a70..71e322fa 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -1014,7 +1014,7 @@ def assertEqualArrayAny( msg_ = "\n".join(excs) msg = f"{msg}\n{msg_}" if msg else msg_ raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}") - elif expected.__class__.__name__ == "DynamicCache": + elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"): atts = {"key_cache", "value_cache"} self.assertEqualArrayAny( {k: expected.__dict__.get(k, None) for k in atts}, diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 6fbc7297..aee5f2bb 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -174,7 +174,7 @@ def get_inputs( shapes = { "input_ids": {0: batch, 1: seq_length}, "attention_mask": {0: batch, 2: "seq"}, - "cache_position": {1: "seq"}, + "cache_position": {0: "seq"}, "past_key_values": [ [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], diff --git a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py index 4aef5b55..59c0398d 100644 --- a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +++ b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py @@ -57,13 +57,13 @@ def get_tiny_llm( res = get_inputs( model, conf, - dummy_max_token_id=config["vocab_size"], - num_hidden_layers=config["num_hidden_layers"], + dummy_max_token_id=config["vocab_size"], # type: ignore[arg-type] + num_hidden_layers=config["num_hidden_layers"], # type: ignore[arg-type] batch_size=batch_size, sequence_length=sequence_length, sequence_length2=sequence_length2, dynamic_rope=dynamic_rope, - num_key_value_heads=config["num_key_value_heads"], + num_key_value_heads=config["num_key_value_heads"], # type: ignore[arg-type] cls_cache="StaticCache" if use_static_cache else "DynamicCache", ) From fb1844be8cf3e3a745e8c3bb8e597479cd0bca25 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 16 Jun 2025 16:28:42 +0200 Subject: [PATCH 3/7] fix ut --- _doc/examples/plot_export_tiny_llm.py | 13 +++++++++++-- .../ut_torch_models/test_tiny_llms_bypassed.py | 10 ++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/_doc/examples/plot_export_tiny_llm.py b/_doc/examples/plot_export_tiny_llm.py index 06c90384..64b24085 100644 --- a/_doc/examples/plot_export_tiny_llm.py +++ b/_doc/examples/plot_export_tiny_llm.py @@ -33,6 +33,7 @@ from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.torch_helper import steal_forward from onnx_diagnostic.torch_models.llms import get_tiny_llm +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str MODEL_NAME = "arnir0/Tiny-LLM" @@ -131,7 +132,11 @@ def _forward_(*args, _f=None, **kwargs): try: ep = torch.export.export( - untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False + untrained_model, + (), + kwargs=cloned_inputs, + dynamic_shapes=use_dyn_not_str(dynamic_shapes), + strict=False, ) print("It worked:") print(ep) @@ -166,7 +171,11 @@ def _forward_(*args, _f=None, **kwargs): try: ep = torch.export.export( - model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, 
strict=False + model, + (), + kwargs=cloned_inputs, + dynamic_shapes=use_dyn_not_str(dynamic_shapes), + strict=False, ) print("It worked:") print(ep) diff --git a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py index 746ad78c..4d0d7b0c 100644 --- a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py +++ b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py @@ -2,12 +2,13 @@ import unittest import torch from transformers.cache_utils import DynamicCache -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout from onnx_diagnostic.torch_models.llms import get_tiny_llm from onnx_diagnostic.torch_models.llms import get_phi2 from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.torch_helper import torch_deepcopy from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( patched_DynamicCache, ) @@ -15,6 +16,7 @@ class TestTinyLlmBypassed(ExtTestCase): @ignore_warnings(UserWarning) + @hide_stdout() def test_export_tiny_llm_2_bypassed(self): data = get_tiny_llm() model, inputs = data["model"], data["inputs"] @@ -50,7 +52,11 @@ def debug(): debug() ep = torch.export.export( - model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False + model, + (), + kwargs=inputs, + dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), + strict=False, ) got = ep.module()(**inputs) self.assertEqualArrayAny(expected, got) From 4d2ca8818639b27ef8aa62ecec981dbad2963574 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 16 Jun 2025 16:37:29 +0200 Subject: [PATCH 4/7] fix static --- _unittests/ut_torch_models/test_tiny_llms.py | 4 +++- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index 3e8c7de9..8d636b26 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -1,7 +1,7 @@ import copy import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers from onnx_diagnostic.torch_models.llms import get_tiny_llm from onnx_diagnostic.helpers import string_type from onnx_diagnostic.torch_export_patches import torch_export_patches @@ -33,6 +33,7 @@ def test_tiny_llm_export_dynamic(self): got = ep.module()(**inputs) self.assertEqualArrayAny(expected, got) + @requires_transformers("4.52") def test_tiny_llm_run_static(self): data = get_tiny_llm(use_static_cache=True) model, inputs = data["model"], data["inputs"] @@ -40,6 +41,7 @@ def test_tiny_llm_run_static(self): model(**inputs) @ignore_warnings(UserWarning) + @requires_transformers("4.52") def test_tiny_llm_export_static(self): data = get_tiny_llm(use_static_cache=True) model, inputs = data["model"], data["inputs"] diff --git a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py index 59c0398d..f8b7fe63 100644 --- a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +++ b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py @@ -50,7 +50,8 @@ def get_tiny_llm( config.update(**kwargs) conf = 
transformers.LlamaConfig(**config) - conf.cache_implementation = "static" + if use_static_cache: + conf.cache_implementation = "static" model = transformers.LlamaForCausalLM(conf) model.eval() From 77a98936bab9f60dc955234f660be1bf1685a233 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 16 Jun 2025 16:41:23 +0200 Subject: [PATCH 5/7] fix issues --- _doc/examples/plot_export_tiny_llm_patched.py | 5 +++-- onnx_diagnostic/_command_lines_parser.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/_doc/examples/plot_export_tiny_llm_patched.py b/_doc/examples/plot_export_tiny_llm_patched.py index 5ed9566e..8790de82 100644 --- a/_doc/examples/plot_export_tiny_llm_patched.py +++ b/_doc/examples/plot_export_tiny_llm_patched.py @@ -70,6 +70,7 @@ from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered from onnx_diagnostic.helpers import string_type from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str from onnx_diagnostic.torch_models.llms import get_tiny_llm @@ -110,7 +111,7 @@ untrained_model, (), kwargs=modificator(cloned_inputs), - dynamic_shapes=dynamic_shapes, + dynamic_shapes=use_dyn_not_str(dynamic_shapes), strict=False, # mandatory for torch==2.6 ) print("It worked:") @@ -131,7 +132,7 @@ model, (), kwargs=modificator(cloned_inputs), - dynamic_shapes=dynamic_shapes, + dynamic_shapes=use_dyn_not_str(dynamic_shapes), strict=False, # mandatory for torch==2.6 ) print("It worked:") diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 282de6b9..fa9df1fc 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -138,7 +138,6 @@ def get_parser_print() -> ArgumentParser: "\n" ) ), - formatter_class=RawTextHelpFormatter, ) parser.add_argument("input", type=str, help="onnx model to load") return parser From 74a41304d48c72209b2afb7ee8bf4c68d5c8ca7a Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 16 Jun 2025 18:20:36 +0200 Subject: [PATCH 6/7] fix --- _doc/conf.py | 1 + _unittests/ut_torch_models/test_tiny_llms.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/_doc/conf.py b/_doc/conf.py index 1e989447..72a4fd94 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -136,6 +136,7 @@ def linkcode_resolve(domain, info): ("py:class", "transformers.cache_utils.EncoderDecoderCache"), ("py:class", "transformers.cache_utils.MambaCache"), ("py:class", "transformers.cache_utils.SlidingWindowCache"), + ("py:class", "transformers.cache_utils.StaticCache"), ("py:class", "transformers.configuration_utils.PretrainedConfig"), ("py:class", "transformers.modeling_outputs.BaseModelOutput"), ("py:class", "transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding"), diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index 8d636b26..5e8ec565 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -1,7 +1,12 @@ import copy import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + ignore_warnings, + requires_transformers, + requires_torch, +) from onnx_diagnostic.torch_models.llms import get_tiny_llm from onnx_diagnostic.helpers import string_type from onnx_diagnostic.torch_export_patches import torch_export_patches @@ -16,6 +21,7 @@ def 
test_tiny_llm_run_dynamic(self): model(**inputs) @ignore_warnings(UserWarning) + @requires_torch("2.8") def test_tiny_llm_export_dynamic(self): data = get_tiny_llm() model, inputs = data["model"], data["inputs"] From ca20fb5f28ff5084242c26c434a2453d2e5488d6 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 16 Jun 2025 18:31:31 +0200 Subject: [PATCH 7/7] fix --- _unittests/ut_torch_models/test_tiny_llms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index 5e8ec565..2ac4dfb3 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -48,6 +48,7 @@ def test_tiny_llm_run_static(self): @ignore_warnings(UserWarning) @requires_transformers("4.52") + @requires_torch("2.8") def test_tiny_llm_export_static(self): data = get_tiny_llm(use_static_cache=True) model, inputs = data["model"], data["inputs"]