From 6bca0b4af4b4b9232499dd22b113d402964e447c Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 16:54:49 +0200 Subject: [PATCH 01/17] patch for _compute_dynamic_ntk_parameters --- .../ut_torch_models/test_validate_models.py | 2 +- .../test_validate_whole_models.py | 41 +++- onnx_diagnostic/helpers/config_helper.py | 12 +- .../onnx_export_errors.py | 114 +++++++--- .../patches/patch_transformers.py | 206 ++++++++++++++++-- .../hghub/hub_data_cached_configs.py | 40 ++++ 6 files changed, 354 insertions(+), 61 deletions(-) diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index 43328545..d6b0b07d 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -28,7 +28,7 @@ def test_validate_microsoft_phi4_reasoning(self): patch=True, rewrite=True, stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, - dump_folder="dump_test_validate_model_custom", + dump_folder="dump_test/validate_microsoft_phi4_reasoning", ) self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-5) self.assertIn("onnx_filename", data) diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py index c2bb69ca..ff5d9da8 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models.py +++ b/_unittests/ut_torch_models/test_validate_whole_models.py @@ -1,6 +1,7 @@ import copy import unittest import packaging.version as pv +import onnx import torch from onnx_diagnostic.ext_test_case import ( ExtTestCase, @@ -63,7 +64,7 @@ def test_validate_model_export(self): do_run=True, verbose=10, exporter="export-nostrict", - dump_folder="dump_test_validate_model_export", + dump_folder="dump_test/validate_model_export", patch=True, ) self.assertIsInstance(summary, dict) @@ -79,7 +80,7 @@ def test_validate_model_onnx_dynamo_ir(self): do_run=True, verbose=10, exporter="onnx-dynamo", - dump_folder="dump_test_validate_model_onnx_dynamo", + dump_folder="dump_test/validate_model_onnx_dynamo_ir", patch=True, stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, optimization="ir", @@ -104,7 +105,7 @@ def test_validate_model_onnx_dynamo_os_ort(self): do_run=True, verbose=10, exporter="onnx-dynamo", - dump_folder="dump_test_validate_model_onnx_dynamo", + dump_folder="dump_test/validate_model_onnx_dynamo_os_ort", patch=True, stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, optimization="os_ort", @@ -126,7 +127,7 @@ def test_validate_model_custom_os_ort(self): do_run=True, verbose=10, exporter="custom", - dump_folder="dump_validate_model_custom_os_ort", + dump_folder="dump_test/validate_model_custom_os_ort", patch=True, stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, optimization="default+os_ort", @@ -148,7 +149,7 @@ def test_validate_model_custom(self): do_run=True, verbose=10, exporter="custom", - dump_folder="dump_test_validate_model_custom", + dump_folder="dump_test/validate_model_custom_tiny_llm", patch=True, stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, optimization="default", @@ -177,7 +178,7 @@ def test_validate_model_custom_torch(self): do_run=True, verbose=10, exporter="custom-noinline", - dump_folder="dump_test_validate_model_custom_torch", + dump_folder="dump_test/validate_model_custom_torch", patch=True, stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, 
optimization="default", @@ -221,7 +222,7 @@ def test_validate_model_modelbuilder(self): do_run=True, verbose=10, exporter="modelbuilder", - dump_folder="dump_test_validate_model_modelbuilder", + dump_folder="dump_test/validate_model_modelbuilder", ) self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) @@ -240,7 +241,7 @@ def test_validate_model_vit_model(self): do_run=True, verbose=10, exporter="onnx-dynamo", - dump_folder="dump_test_validate_model_onnx_dynamo", + dump_folder="dump_test/validate_model_vit_model", inputs2=True, ) self.assertIsInstance(summary, dict) @@ -254,6 +255,30 @@ def test_validate_model_vit_model(self): onnx_filename = data["onnx_filename"] self.assertExists(onnx_filename) + @requires_torch("2.7") + @hide_stdout() + @ignore_warnings(FutureWarning) + @requires_transformers("4.51") + def test_validate_phi3_mini_4k_instruct(self): + mid = "microsoft/Phi-3-mini-4k-instruct" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="custom", + dump_folder="dump_test/validate_phi3_mini_4k_instruct", + inputs2=True, + patch=True, + rewrite=True, + model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}}, + ) + self.assertIsInstance(summary, dict) + self.assertIsInstance(data, dict) + onnx_filename = data["onnx_filename"] + onx = onnx.load(onnx_filename) + op_types = set(n.op_type for n in onx.graph.node) + self.assertIn("If", op_types) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py index 4b8ac43f..c22340ab 100644 --- a/onnx_diagnostic/helpers/config_helper.py +++ b/onnx_diagnostic/helpers/config_helper.py @@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]): config._attn_implementation_autoset = False continue if isinstance(v, dict): - assert hasattr( - config, k - ), f"missing attribute {k!r} in config={config}, cannot update it with {v}" - update_config(getattr(config, k), v) + if not hasattr(config, k) or getattr(config, k) is None: + setattr(config, k, v) + continue + existing = getattr(config, k) + if type(existing) is dict: + existing.update(v) + else: + update_config(getattr(config, k), v) continue setattr(config, k, v) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index ac7e66fa..b11f5f94 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -1,5 +1,8 @@ +import functools +import importlib import contextlib -from typing import Any, Callable, Dict, List, Optional +import re +from typing import Any, Callable, Dict, List, Optional, Tuple from .onnx_export_serialization import ( register_cache_serialization, unregister_cache_serialization, @@ -7,18 +10,20 @@ from .patches import patch_transformers as patch_transformers_list -def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]: +def get_function(name: str) -> Tuple["module", "function"]: # noqa: F821 """ - Applies all patches defined in classes prefixed by ``patched_`` - ``cls._PATCHED_CLASS_`` defines the class to patch, - ``cls._PATCHES_`` defines the method to patch. - The returns information needs to be sent to :func:`unpatch_module_or_classes` - to revert the changes. - - :param mod: module of list of clsses to patch - :param verbose: verbosity - :return: patch info + Returns the module and the function based on its name. 
""" + spl = name.split(".") + module_name = ".".join(spl[:-1]) + fname = spl[-1] + mod = importlib.import_module(module_name) + return mod, getattr(mod, fname) + + +@functools.lru_cache +def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]: + """Returns the list of patches to make for a specific module.""" if isinstance(mod, list): to_patch = mod name = "list" @@ -29,10 +34,50 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call v = getattr(mod, k) if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): to_patch.append(v) + else: + # a function + doc = v.__doc__ + if doc.startswith("manual patch"): + continue + reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]") + fall = reg.findall(doc) + assert ( + len(fall) == 1 + ), f"Unable to find patching information for {v} in \n{doc}" + fmod, f = get_function(fall[0]) + to_patch.append({"module": fmod, "function": f, "patch": v}) + name = mod.__name__ + return name, to_patch + + +def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]: + """ + Applies all patches defined in classes prefixed by ``patched_`` + ``cls._PATCHED_CLASS_`` defines the class to patch, + ``cls._PATCHES_`` defines the method to patch. + The returns information needs to be sent to :func:`unpatch_module_or_classes` + to revert the changes. + + :param mod: module of list of clsses to patch + :param verbose: verbosity + :return: patch info + """ + name, to_patch = get_patches(mod, verbose) res = {} for cls in to_patch: + if isinstance(cls, dict): + # a function + keep = {} + original = cls["module"] + f = cls["function"] + res[f] = f + if verbose: + print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}") + setattr(original, f.__name__, cls["patch"]) + continue + original = cls._PATCHED_CLASS_ methods = cls._PATCHES_ if verbose: @@ -53,30 +98,35 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo :param mod: module of list of clsses to patch :param verbose: verbosity """ - if isinstance(mod, list): - to_patch = mod - name = "list" - else: - to_patch = [] - for k in dir(mod): - if k.startswith("patched_"): - v = getattr(mod, k) - if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): - to_patch.append(v) - name = mod.__name__ - set_patch = set(to_patch) + name, to_patch = get_patches(mod, verbose) + set_patch_cls = {i for i in to_patch if not isinstance(i, dict)} + dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)} for cls, methods in info.items(): - assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})" + if cls in set_patch_cls: + if verbose: + print( + f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}" + ) + original = cls._PATCHED_CLASS_ + for n, v in methods.items(): + if v is None: + # The method did not exist. We remove it. + delattr(original, n) + else: + setattr(original, n, v) + continue + assert cls in dict_patch_fct, ( + f"No patch registered for {cls} in {mod} " + f"(found {set_patch_cls} and {set(dict_patch_fct)})" + ) + patch = dict_patch_fct[cls] if verbose: - print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}") - original = cls._PATCHED_CLASS_ - for n, v in methods.items(): - if v is None: - # The method did not exist. We remove it. 
-                delattr(original, n)
-            else:
-                setattr(original, n, v)
+            print(
+                f"[unpatch_module_or_classes] function "
+                f"{patch['module'].__name__}.{cls.__name__}"
+            )
+        setattr(patch["module"], cls.__name__, patch["function"])
 
 
 @contextlib.contextmanager
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index 0f75c7e8..1be4be0a 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -11,7 +11,7 @@
 
 def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """Patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
     from ...helpers import string_type
 
     dimensions: List[Tuple[Optional[int], ...]] = [
@@ -534,9 +534,145 @@ def prepare_inputs_for_generation(
     return model_inputs
 
 
-def patched_dynamic_rope_update(rope_forward):
+def patched__compute_dynamic_ntk_parameters(
+    config: Optional[transformers.PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+    **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+    """
+    ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
+
+    Computes the inverse frequencies with NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length,
+            used to update the dynamic RoPE at inference time.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous
+            RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`),
+        containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the
+        computed cos/sin (unused in this type of RoPE).
     """
-    patch:transformers.modeling_rope_utils.dynamic_rope_update
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_dynamic_ntk_parameters`, got "
+            f"`rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+        max_position_embeddings = rope_kwargs["max_position_embeddings"]
+        factor = rope_kwargs["factor"]
+    elif config is not None:
+        base = config.rope_theta
+        partial_rotary_factor = (
+            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        )
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        dim = int(head_dim * partial_rotary_factor)
+        max_position_embeddings = config.max_position_embeddings
+        factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # seq_len: default to max_position_embeddings, e.g. 
at init time + # seq_len = seq_len if seq_len is not None and + # seq_len > max_position_embeddings else max_position_embeddings + if seq_len is None: + seq_len = max_position_embeddings + else: + torch._check(isinstance(seq_len, torch.Tensor)) + seq_len = torch.max( + seq_len, torch.Tensor(max_position_embeddings, dtype=seq_len.dtype) + ) + + # Compute the inverse frequencies + base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** ( + dim / (dim - 2) + ) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) + / dim + ) + ) + return inv_freq, attention_factor + + +def patched_dynamic_rope_update(rope_forward): + """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]`` + + ``rope_type`` is determined in the constructor of class + :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`. + + .. code-block:: python + + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + + The original code of the patched function: + + .. code-block:: python + + def dynamic_rope_update(rope_forward): + def longrope_frequency_update(self, position_ids, device): + seq_len = torch.max(position_ids) + 1 + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = + self.config.original_max_position_embeddings + else: + original_max_position_embeddings = + self.config.max_position_embeddings + if seq_len > original_max_position_embeddings: + if not hasattr(self, "long_inv_freq"): + self.long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=original_max_position_embeddings + 1 + ) + self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) + else: + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + + def dynamic_frequency_update(self, position_ids, device): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and + self.max_seq_len_cached > self.original_max_seq_len: + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @wraps(rope_forward) + def wrapper(self, x, position_ids): + if "dynamic" in self.rope_type: + dynamic_frequency_update(self, position_ids, device=x.device) + elif self.rope_type == "longrope": + longrope_frequency_update(self, position_ids, device=x.device) + return rope_forward(self, x, position_ids) + + return wrapper + """ def longrope_frequency_update(self, position_ids, device): @@ -565,21 +701,59 @@ def longrope_frequency_update(self, position_ids, device): # self.inv_freq = self.original_inv_freq def dynamic_frequency_update(self, position_ids, device): + # constructor: + # - self.max_seq_len_cached = config.max_position_embeddings + # - self.original_max_seq_len = config.max_position_embeddings + # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + + # This behaviour is difficult to translate. + # The sequence always grows. 
+ # The test should always True. + # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len + # + # if seq_len > self.max_seq_len_cached: # growth + # inv_freq, self.attention_scaling = self.rope_init_fn( + # self.config, device, seq_len=seq_len + # ) + # self.register_buffer("inv_freq", inv_freq, persistent=False) + # self.max_seq_len_cached = seq_len + # + # So we should not need what follows. + # + # cond = (seq_len > self.max_seq_len_cached).item() + # self.attention_scaling = torch.cond( + # cond, + # (lambda x, y: x.clone()), + # (lambda x, y: y.clone()), + # [attention_scaling, self.attention_scaling], + # ) + seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.max_seq_len_cached = seq_len + long_inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len + ) - if ( - seq_len < self.original_max_seq_len - and self.max_seq_len_cached > self.original_max_seq_len - ): - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len + # Second test to translate. + # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True. + # But in that case the following condition is a way to restore the original cache. + + # if ( + # seq_len < self.original_max_seq_len + # and self.max_seq_len_cached > self.original_max_seq_len + # ): + # self.original_inv_freq = self.original_inv_freq.to(device) + # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + # self.max_seq_len_cached = self.original_max_seq_len + + original_inv_freq = self.original_inv_freq.to(device) + cond = (seq_len >= self.original_max_seq_len).item() + inv_freq = torch.cond( + cond, + (lambda x, y: x.clone()), + (lambda x, y: y.clone()), + [long_inv_freq, original_inv_freq], + ) + self.inv_freq = inv_freq @wraps(rope_forward) def wrapper(self, x, position_ids): diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index a2d6ce6e..03453fbf 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -3953,6 +3953,46 @@ def _ccached_facebook_bart_large_cnn(): ) +def _ccached_microsoft_phi3_mini_4k_instruct(): + "microsoft/Phi-3-mini-4k-instruct" + return transformers.Phi3Config( + **{ + "_name_or_path": "Phi-3-mini-4k-instruct", + "architectures": ["Phi3ForCausalLM"], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM", + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.2", + "use_cache": true, + 
"attention_bias": false, + "vocab_size": 32064, + } + ) + + def _ccached_microsoft_phi4_reasoning(): "microsoft/Phi-4-mini-reasoning" return transformers.Phi3Config( From 528fa412865a3bed95d627b83a51f1e696333cb3 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 18:08:34 +0200 Subject: [PATCH 02/17] change --- .../test_validate_whole_models.py | 2 +- onnx_diagnostic/_command_lines_parser.py | 15 +++++++-- .../patches/patch_transformers.py | 32 ++++++++++++++++--- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py index ff5d9da8..adc53951 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models.py +++ b/_unittests/ut_torch_models/test_validate_whole_models.py @@ -265,7 +265,7 @@ def test_validate_phi3_mini_4k_instruct(self): mid, do_run=True, verbose=10, - exporter="custom", + exporter="onnx-dynamo", dump_folder="dump_test/validate_phi3_mini_4k_instruct", inputs2=True, patch=True, diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index f3515cb4..bbcdf47e 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -5,7 +5,7 @@ import sys import textwrap import onnx -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional, Union from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction from textwrap import dedent @@ -291,6 +291,14 @@ def _cmd_config(argv: List[Any]): print(f"task: {task_from_id(args.mid)}") +def _parse_json(value: str) -> Union[str, Dict[str, Any]]: + assert isinstance(value, str), f"value should be string but value={value!r}" + if value and value[0] == "{" and value[-1] == "}": + # a dictionary + return json.loads(value.replace("'", '"')) + return value + + class _ParseDict(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): d = getattr(namespace, self.dest) or {} @@ -314,7 +322,7 @@ def __call__(self, parser, namespace, values, option_string=None): continue except (TypeError, ValueError): pass - d[key] = value + d[key] = _parse_json(value) setattr(namespace, self.dest, d) @@ -430,7 +438,8 @@ def get_parser_validate() -> ArgumentParser: metavar="KEY=VALUE", nargs="*", help="Additional model options, use to change some parameters of the model, " - "example: --mop attn_implementation=eager", + "example: ``--mop attn_implementation=eager`` or " + "``--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"``", action=_ParseDict, ) parser.add_argument( diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 1be4be0a..4a7a809d 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -540,7 +540,7 @@ def patched__compute_dynamic_ntk_parameters( seq_len: Optional[int] = None, **rope_kwargs, ) -> Tuple["torch.Tensor", float]: - """ + """manual patch: ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]`` Computes the inverse frequencies with NTK scaling. 
@@ -594,8 +594,9 @@ def patched__compute_dynamic_ntk_parameters( seq_len = max_position_embeddings else: torch._check(isinstance(seq_len, torch.Tensor)) - seq_len = torch.max( - seq_len, torch.Tensor(max_position_embeddings, dtype=seq_len.dtype) + seq_len = torch.maximum( + seq_len, + torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device), ) # Compute the inverse frequencies @@ -676,13 +677,23 @@ def wrapper(self, x, position_ids): """ def longrope_frequency_update(self, position_ids, device): + # It is no use to patch the function after the model is created + # as rope_init_fn is an attribute set to one function when the model + # is created and when no patch is applied yet. + # So we select the patched version here. + rope_init_fn = ( + patched__compute_dynamic_ntk_parameters + if self.rope_init_fn + is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters + else self.rope_init_fn + ) seq_len = torch.max(position_ids) + 1 if hasattr(self.config, "original_max_position_embeddings"): original_max_position_embeddings = self.config.original_max_position_embeddings else: original_max_position_embeddings = self.config.max_position_embeddings # At export time, seq_len is unknown. - long_inv_freq, _ = self.rope_init_fn( + long_inv_freq, _ = rope_init_fn( self.config, device, seq_len=original_max_position_embeddings + 1 ) original_inv_freq = self.original_inv_freq.to(device) @@ -706,6 +717,17 @@ def dynamic_frequency_update(self, position_ids, device): # - self.original_max_seq_len = config.max_position_embeddings # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + # It is no use to patch the function after the model is created + # as rope_init_fn is an attribute set to one function when the model + # is created and when no patch is applied yet. + # So we select the patched version here. + rope_init_fn = ( + patched__compute_dynamic_ntk_parameters + if self.rope_init_fn + is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters + else self.rope_init_fn + ) + # This behaviour is difficult to translate. # The sequence always grows. # The test should always True. 
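The hunks on both sides of this point deal with the same constraint: ``self.rope_init_fn`` is bound when the model is constructed, before any patch is applied, so the attribute keeps pointing at the stock ``transformers`` function and the patched version has to be substituted by identity at call time. A minimal sketch of that dispatch, where ``patched_fn`` stands in for ``patched__compute_dynamic_ntk_parameters``:

.. code-block:: python

    import transformers.modeling_rope_utils as rope_utils

    def select_rope_init(bound_fn, patched_fn):
        # the model captured the unpatched function at construction time,
        # so compare by identity and swap in the export-friendly version
        if bound_fn is rope_utils._compute_dynamic_ntk_parameters:
            return patched_fn
        return bound_fn
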
@@ -729,7 +751,7 @@ def dynamic_frequency_update(self, position_ids, device): # ) seq_len = torch.max(position_ids) + 1 - long_inv_freq, self.attention_scaling = self.rope_init_fn( + long_inv_freq, self.attention_scaling = rope_init_fn( self.config, device, seq_len=seq_len ) From b187624974ccfbdc49bf6d6d1da9cc00d3445a22 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 18:57:17 +0200 Subject: [PATCH 03/17] fix --- .../ut_torch_models/test_validate_whole_models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py index adc53951..d1d72223 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models.py +++ b/_unittests/ut_torch_models/test_validate_whole_models.py @@ -259,18 +259,18 @@ def test_validate_model_vit_model(self): @hide_stdout() @ignore_warnings(FutureWarning) @requires_transformers("4.51") - def test_validate_phi3_mini_4k_instruct(self): - mid = "microsoft/Phi-3-mini-4k-instruct" + def test_validate_phi35_mini_instruct(self): + mid = "microsoft/Phi-3.5-mini-instruct" summary, data = validate_model( mid, do_run=True, verbose=10, - exporter="onnx-dynamo", - dump_folder="dump_test/validate_phi3_mini_4k_instruct", + exporter="custom", + dump_folder="dump_test/validate_phi35_mini_instruct", inputs2=True, patch=True, rewrite=True, - model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}}, + # model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}}, ) self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) From 905cd25c68f93a096bef2a6bb0be5706bfec27ce Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 20:18:15 +0200 Subject: [PATCH 04/17] custom patch --- .../onnx_export_errors.py | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index b11f5f94..ae5bf991 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -24,30 +24,26 @@ def get_function(name: str) -> Tuple["module", "function"]: # noqa: F821 @functools.lru_cache def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]: """Returns the list of patches to make for a specific module.""" - if isinstance(mod, list): - to_patch = mod - name = "list" - else: - to_patch = [] - for k in dir(mod): - if k.startswith("patched_"): - v = getattr(mod, k) - if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): - to_patch.append(v) - else: - # a function - doc = v.__doc__ - if doc.startswith("manual patch"): - continue - reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]") - fall = reg.findall(doc) - assert ( - len(fall) == 1 - ), f"Unable to find patching information for {v} in \n{doc}" - fmod, f = get_function(fall[0]) - to_patch.append({"module": fmod, "function": f, "patch": v}) - - name = mod.__name__ + to_patch = [] + for k in dir(mod): + if k.startswith("patched_"): + v = getattr(mod, k) + if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): + to_patch.append(v) + else: + # a function + doc = v.__doc__ + if doc.startswith("manual patch"): + continue + reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]") + fall = reg.findall(doc) + assert ( + len(fall) == 1 + ), f"Unable to find patching information for {v} in \n{doc}" + fmod, f = get_function(fall[0]) + 
to_patch.append({"module": fmod, "function": f, "patch": v}) + + name = mod.__name__ return name, to_patch @@ -63,7 +59,11 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call :param verbose: verbosity :return: patch info """ - name, to_patch = get_patches(mod, verbose) + if isinstance(mod, list): + to_patch = mod + name = "list" + else: + name, to_patch = get_patches(mod, verbose) res = {} for cls in to_patch: @@ -98,7 +98,12 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo :param mod: module of list of clsses to patch :param verbose: verbosity """ - name, to_patch = get_patches(mod, verbose) + if isinstance(mod, list): + to_patch = mod + name = "list" + else: + name, to_patch = get_patches(mod, verbose) + set_patch_cls = {i for i in to_patch if not isinstance(i, dict)} dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)} From 371b55d444f6be31e752b8d5cd35e47546911754 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 20:46:03 +0200 Subject: [PATCH 05/17] patch --- onnx_diagnostic/torch_export_patches/onnx_export_errors.py | 2 +- .../torch_export_patches/patches/patch_transformers.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index ae5bf991..eccbd8c1 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -32,7 +32,7 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]: to_patch.append(v) else: # a function - doc = v.__doc__ + doc = v.__doc__.lstrip() if doc.startswith("manual patch"): continue reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]") diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 4a7a809d..d5f10355 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -540,7 +540,8 @@ def patched__compute_dynamic_ntk_parameters( seq_len: Optional[int] = None, **rope_kwargs, ) -> Tuple["torch.Tensor", float]: - """manual patch: + """ + manual patch: ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]`` Computes the inverse frequencies with NTK scaling. 
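At this point in the series the discovery convention is settled: a ``patched_*`` class advertises ``_PATCHED_CLASS_`` and ``_PATCHES_``, while a ``patched_*`` function either opens its docstring with ``manual patch`` (ignored by ``get_patches``) or embeds a ``[patch:...]`` tag naming the callable it replaces. A minimal sketch of the tag extraction, reusing the regular expression from ``get_patches`` (the character classes ``[[]`` and ``[]]`` match literal brackets):

.. code-block:: python

    import re

    reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")

    def patched_example():
        """``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``"""

    fall = reg.findall(patched_example.__doc__.lstrip())
    assert fall == ["transformers.modeling_rope_utils._compute_dynamic_ntk_parameters"]
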
From 4c6517585ba83d837f1d0ddd3afa1fbfb15fd63d Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 21:05:43 +0200 Subject: [PATCH 06/17] doc --- _doc/conf.py | 1 + onnx_diagnostic/torch_export_patches/onnx_export_errors.py | 6 ++---- .../torch_export_patches/patches/patch_transformers.py | 2 ++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/_doc/conf.py b/_doc/conf.py index f0548f08..1e989447 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -138,6 +138,7 @@ def linkcode_resolve(domain, info): ("py:class", "transformers.cache_utils.SlidingWindowCache"), ("py:class", "transformers.configuration_utils.PretrainedConfig"), ("py:class", "transformers.modeling_outputs.BaseModelOutput"), + ("py:class", "transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding"), ("py:func", "torch.export._draft_export.draft_export"), ("py:func", "torch._export.tools.report_exportability"), ("py:func", "torch.utils._pytree.register_pytree_node"), diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index eccbd8c1..ad1309e6 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -10,10 +10,8 @@ from .patches import patch_transformers as patch_transformers_list -def get_function(name: str) -> Tuple["module", "function"]: # noqa: F821 - """ - Returns the module and the function based on its name. - """ +def get_function(name: str) -> Tuple[type, Callable]: + """Returns the module and the function based on its name.""" spl = name.split(".") module_name = ".".join(spl[:-1]) fname = spl[-1] diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index d5f10355..a7553b0c 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -546,6 +546,7 @@ def patched__compute_dynamic_ntk_parameters( Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + Args: config ([`~transformers.PretrainedConfig`]): The model configuration. @@ -557,6 +558,7 @@ def patched__compute_dynamic_ntk_parameters( rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. 
+ Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the From cb24a5bbce8d92cede8bcfb8c4493be483f3fe6d Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 21:18:24 +0200 Subject: [PATCH 07/17] file --- .../test_validate_whole_models.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py index d1d72223..cc9981c3 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models.py +++ b/_unittests/ut_torch_models/test_validate_whole_models.py @@ -270,7 +270,6 @@ def test_validate_phi35_mini_instruct(self): inputs2=True, patch=True, rewrite=True, - # model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}}, ) self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) @@ -279,6 +278,29 @@ def test_validate_phi35_mini_instruct(self): op_types = set(n.op_type for n in onx.graph.node) self.assertIn("If", op_types) + @requires_torch("2.7") + @hide_stdout() + @ignore_warnings(FutureWarning) + @requires_transformers("4.51") + def test_validate_phi35_4k_mini_instruct(self): + mid = "microsoft/Phi-3.5-mini-4k-instruct" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="custom", + dump_folder="dump_test/validate_phi35_mini_instruct", + inputs2=True, + patch=True, + rewrite=True, + model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}}, + ) + self.assertIsInstance(summary, dict) + self.assertIsInstance(data, dict) + onnx_filename = data["onnx_filename"] + onx = onnx.load(onnx_filename) + op_types = set(n.op_type for n in onx.graph.node) + self.assertIn("If", op_types) if __name__ == "__main__": unittest.main(verbosity=2) From 2e1104410c8af1aba8e5bb17597757f4a2a36d99 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 12 Jun 2025 21:21:01 +0200 Subject: [PATCH 08/17] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c34f031f..f73c394e 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.7.0 +++++ +* :pr:`145`: patch for _compute_dynamic_ntk_parameters (Phi3RotaryEmbedding) * :pr:`144`: support for second inputs with different dimension, rename test_helper into validate, support ``interpolate_pos_encoding`` for ``VitModel``, From 1212c696186ea0851b1300e82eada3117bfbf50e Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 14 Jun 2025 11:19:33 +0200 Subject: [PATCH 09/17] fix unittest --- _unittests/ut_helpers/test_helper.py | 2 +- _unittests/ut_torch_models/test_hghub_api.py | 6 +++-- .../ut_torch_models/test_hghub_model.py | 6 ++--- .../test_validate_whole_models.py | 2 +- onnx_diagnostic/torch_models/hghub/hub_api.py | 22 +++++++++++++++---- .../torch_models/hghub/model_inputs.py | 3 +++ 6 files changed, 30 insertions(+), 11 deletions(-) diff --git a/_unittests/ut_helpers/test_helper.py b/_unittests/ut_helpers/test_helper.py index b42ed072..70abc638 100644 --- a/_unittests/ut_helpers/test_helper.py +++ b/_unittests/ut_helpers/test_helper.py @@ -584,7 +584,7 @@ def test_flatten_encoder_decoder_cache(self): self.assertIn("EncoderDecoderCache", s) def test_string_typeçconfig(self): - conf = get_pretrained_config("microsoft/phi-2") + conf = get_pretrained_config("microsoft/phi-2", use_only_preinstalled=True) s = string_type(conf) self.assertStartsWith("PhiConfig(**{", s) diff --git a/_unittests/ut_torch_models/test_hghub_api.py 
b/_unittests/ut_torch_models/test_hghub_api.py index 429099df..3766560b 100644 --- a/_unittests/ut_torch_models/test_hghub_api.py +++ b/_unittests/ut_torch_models/test_hghub_api.py @@ -72,14 +72,16 @@ def test_task_from_id_long(self): @requires_torch("2.7") @hide_stdout() def test_get_pretrained_config(self): - conf = get_pretrained_config("microsoft/phi-2") + conf = get_pretrained_config("microsoft/phi-2", use_only_preinstalled=True) self.assertNotEmpty(conf) @requires_transformers("4.50") @requires_torch("2.7") @hide_stdout() def test_get_pretrained_config_options(self): - conf = get_pretrained_config("microsoft/phi-2", num_key_value_heads=16) + conf = get_pretrained_config( + "microsoft/phi-2", num_key_value_heads=16, use_only_preinstalled=True + ) self.assertNotEmpty(conf) self.assertEqual(conf.num_key_value_heads, 16) diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py index 7ea6fa80..1e644cad 100644 --- a/_unittests/ut_torch_models/test_hghub_model.py +++ b/_unittests/ut_torch_models/test_hghub_model.py @@ -129,11 +129,11 @@ def _diff(c1, c2): try: model(**inputs) except Exception as e: - diff = _diff(get_pretrained_config(mid), data["configuration"]) + cf = get_pretrained_config(mid, use_only_preinstalled=True) + diff = _diff(cf, data["configuration"]) raise AssertionError( f"Computation failed due to {e}.\n--- pretrained\n" - f"{pprint.pformat(get_pretrained_config(mid))}\n" - f"--- modified\n{data['configuration']}\n" + f"{pprint.pformat(cf)}\n--- modified\n{data['configuration']}\n" f"--- diff\n{diff}" ) from e # different expected value for different version of transformers diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py index 0cadca23..d2bac85c 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models.py +++ b/_unittests/ut_torch_models/test_validate_whole_models.py @@ -283,7 +283,7 @@ def test_validate_phi35_mini_instruct(self): @ignore_warnings(FutureWarning) @requires_transformers("4.51") def test_validate_phi35_4k_mini_instruct(self): - mid = "microsoft/Phi-3.5-mini-4k-instruct" + mid = "microsoft/Phi-3-mini-4k-instruct" summary, data = validate_model( mid, do_run=True, diff --git a/onnx_diagnostic/torch_models/hghub/hub_api.py b/onnx_diagnostic/torch_models/hghub/hub_api.py index 6067b120..e071b1b0 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_api.py +++ b/onnx_diagnostic/torch_models/hghub/hub_api.py @@ -2,6 +2,7 @@ import functools import json import os +import pprint from typing import Any, Dict, List, Optional, Union import transformers from huggingface_hub import HfApi, model_info, hf_hub_download @@ -33,10 +34,14 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig return res -def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.PretrainedConfig]: +def get_cached_configuration( + name: str, exc: bool = False, **kwargs +) -> Optional[transformers.PretrainedConfig]: """ Returns cached configuration to avoid having to many accesses to internet. It returns None if not Cache. The list of cached models follows. + If *exc* is True or if environment variable ``NOHTTP`` is defined, + the function raises an exception if *name* is not found. .. 
runpython:: @@ -54,8 +59,9 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr conf = copy.deepcopy(conf) update_config(conf, kwargs) return conf - if os.environ.get("NOHTTP", ""): - raise AssertionError(f"Unable to find {name!r} in {sorted(cached)}") + assert not exc and not os.environ.get( + "NOHTTP", "" + ), f"Unable to find {name!r} in {pprint.pformat(sorted(cached))}" return None @@ -64,6 +70,7 @@ def get_pretrained_config( trust_remote_code: bool = True, use_preinstalled: bool = True, subfolder: Optional[str] = None, + use_only_preinstalled: bool = False, **kwargs, ) -> Any: """ @@ -77,13 +84,20 @@ def get_pretrained_config( :func:`get_cached_configuration`, the cached list is mostly for unit tests :param subfolder: subfolder for the given model id + :param use_only_preinstalled: if True, raises an exception if not preinstalled :param kwargs: additional kwargs :return: a configuration """ if use_preinstalled: - conf = get_cached_configuration(model_id, subfolder=subfolder, **kwargs) + conf = get_cached_configuration( + model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs + ) if conf is not None: return conf + assert not use_only_preinstalled, ( + f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, " + f"use_preinstalled={use_preinstalled!r}" + ) if subfolder: try: return transformers.AutoConfig.from_pretrained( diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index b6d08720..8b8b42da 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -26,6 +26,7 @@ def get_untrained_model_with_inputs( use_preinstalled: bool = True, add_second_input: bool = False, subfolder: Optional[str] = None, + use_only_preinstalled: bool = True, ) -> Dict[str, Any]: """ Gets a non initialized model similar to the original model @@ -46,6 +47,7 @@ def get_untrained_model_with_inputs( :param add_second_input: provides a second inputs to check a model supports different shapes :param subfolder: subfolder to use for this model id + :param use_only_preinstalled: use only preinstalled version :return: dictionary with a model, inputs, dynamic shapes, and the configuration, some necessary rewriting as well @@ -74,6 +76,7 @@ def get_untrained_model_with_inputs( config = get_pretrained_config( model_id, use_preinstalled=use_preinstalled, + use_only_preinstalled=use_only_preinstalled, subfolder=subfolder, **(model_kwargs or {}), ) From a9b2581320886d1c4db8714af0b7234b77ae894a Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 14 Jun 2025 12:52:12 +0200 Subject: [PATCH 10/17] badge --- README.rst | 2 +- _doc/index.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index f9cdf089..34b7f2ee 100644 --- a/README.rst +++ b/README.rst @@ -22,7 +22,7 @@ onnx-diagnostic: investigate onnx models .. image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/psf/black -.. image:: https://codecov.io/gh/sdpython/onnx-diagnostic/branch/main/graph/badge.svg?token=Wb9ZGDta8J +.. image:: https://codecov.io/gh/sdpython/onnx-diagnostic/graph/badge.svg?token=91T5ZVIP96 :target: https://codecov.io/gh/sdpython/onnx-diagnostic The main feature is about `patches `_: diff --git a/_doc/index.rst b/_doc/index.rst index 4b42d384..11f3a9fc 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -15,7 +15,7 @@ onnx-diagnostic: investigate onnx models .. 
image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/psf/black -.. image:: https://codecov.io/gh/sdpython/onnx-diagnostic/branch/main/graph/badge.svg?token=Wb9ZGDta8J +.. image:: https://codecov.io/gh/sdpython/onnx-diagnostic/graph/badge.svg?token=91T5ZVIP96 :target: https://codecov.io/gh/sdpython/onnx-diagnostic The main feature is about `patches `_: From f95e299cbb529db0e1a2ad49d1b8a64db6537536 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 14 Jun 2025 13:16:21 +0200 Subject: [PATCH 11/17] remove one unit test --- _unittests/ut_torch_models/test_hghub_model.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py index 1e644cad..1b051a50 100644 --- a/_unittests/ut_torch_models/test_hghub_model.py +++ b/_unittests/ut_torch_models/test_hghub_model.py @@ -75,16 +75,6 @@ def test_get_untrained_model_with_inputs_beit(self): # different expected value for different version of transformers self.assertIn((data["size"], data["n_weights"]), [(111448, 27862), (56880, 14220)]) - @hide_stdout() - @ignore_errors(OSError) - def test_get_untrained_model_with_inputs_codellama(self): - mid = "codellama/CodeLlama-7b-Python-hf" - data = get_untrained_model_with_inputs(mid, verbose=1) - model, inputs = data["model"], data["inputs"] - model(**inputs) - # different expected value for different version of transformers - self.assertIn((data["size"], data["n_weights"]), [(547377152, 136844288)]) - @hide_stdout() @ignore_errors(OSError) def test_get_untrained_model_with_inputs_clip_vit(self): From b97f9fafe3a32384af5d80d7c337b660a2a44577 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 14 Jun 2025 13:24:42 +0200 Subject: [PATCH 12/17] add more cache --- .../hghub/hub_data_cached_configs.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index 03453fbf..71a1042a 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -4160,3 +4160,63 @@ def _ccached_ydshieh_tiny_random_vit_for_image_classification(): "transformers_version": "4.24.0.dev0", } ) + + +def _ccached_huggingface_tiny_random_idefics(): + "HuggingFaceM4/tiny-random-idefics" + return transformers.Phi3Config( + **{ + "additional_vocab_size": 2, + "alpha_initializer": "ones", + "alpha_type": "vector", + "alphas_initializer_range": 0.0, + "architectures": ["IdeficsForVisionText2Text"], + "bos_token_id": 1, + "cross_layer_activation_function": "swiglu", + "cross_layer_interval": 1, + "dropout": 0.0, + "eos_token_id": 2, + "ffn_dim": 64, + "freeze_lm_head": false, + "freeze_text_layers": false, + "freeze_text_module_exceptions": [], + "freeze_vision_layers": false, + "freeze_vision_module_exceptions": [], + "hidden_act": "silu", + "hidden_size": 16, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_new_tokens": 128, + "max_position_embeddings": 128, + "model_type": "idefics", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "pad_token_id": 0, + "qk_layer_norms": false, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.27.0.dev0", + "use_cache": true, + "use_resampler": true, + "vocab_size": 32000, + "word_embed_proj_dim": 16, + "vision_config": { + "hidden_act": "gelu", + "embed_dim": 32, + 
"image_size": 30, + "intermediate_size": 37, + "patch_size": 2, + "num_attention_heads": 4, + "num_hidden_layers": 5, + "vision_model_name": "hf-internal-testing/tiny-random-clip", + }, + "perceiver_config": { + "qk_layer_norms_perceiver": false, + "resampler_depth": 2, + "resampler_head_dim": 8, + "resampler_n_heads": 2, + "resampler_n_latents": 16, + }, + } + ) From ca4eb76c4d613569a96315b821b1ce1e9b17eeb9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 14 Jun 2025 14:49:27 +0200 Subject: [PATCH 13/17] more patches --- CHANGELOGS.rst | 1 + _doc/examples/plot_export_hub_codellama.py | 8 +- .../ut_tasks/test_tasks_image_text_to_text.py | 13 +- .../patches/patch_transformers.py | 149 ++++++++++++++++++ onnx_diagnostic/torch_models/hghub/hub_api.py | 8 +- .../hghub/hub_data_cached_configs.py | 60 ------- .../torch_models/hghub/model_inputs.py | 2 +- 7 files changed, 171 insertions(+), 70 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index f73c394e..d7df2d5b 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.7.0 +++++ +* :pr:`146`: patch for IdeficsAttention, IdeficsEmbedding * :pr:`145`: patch for _compute_dynamic_ntk_parameters (Phi3RotaryEmbedding) * :pr:`144`: support for second inputs with different dimension, rename test_helper into validate, diff --git a/_doc/examples/plot_export_hub_codellama.py b/_doc/examples/plot_export_hub_codellama.py index bf932afc..d90df42d 100644 --- a/_doc/examples/plot_export_hub_codellama.py +++ b/_doc/examples/plot_export_hub_codellama.py @@ -20,6 +20,7 @@ import pprint import torch from onnx_diagnostic import doc +from onnx_diagnostic.ext_test_case import unit_test_going from onnx_diagnostic.helpers import string_type from onnx_diagnostic.torch_models.hghub import ( get_untrained_model_with_inputs, @@ -32,7 +33,12 @@ from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -model_id = "codellama/CodeLlama-7b-Python-hf" +model_id = ( + "HuggingFaceM4/tiny-random-idefics" + if unit_test_going() + else "codellama/CodeLlama-7b-Python-hf" +) +print(f"model_id={model_id!r}") print("info", get_model_info(model_id)) # %% diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py index 731371ef..173d628c 100644 --- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py +++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py @@ -1,6 +1,11 @@ import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers, has_torch +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + hide_stdout, + requires_transformers, + requires_torch, +) from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str @@ -8,6 +13,8 @@ class TestTasks(ExtTestCase): @hide_stdout() + @requires_transformers("4.52") + @requires_torch("2.7.99") def test_image_text_to_text(self): mid = "HuggingFaceM4/tiny-random-idefics" data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True) @@ -16,10 +23,6 @@ def test_image_text_to_text(self): model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] model(**inputs) model(**data["inputs2"]) - if not has_transformers("4.55"): - raise unittest.SkipTest("The model has control flow.") - if not 
has_torch("2.7.99"): - raise unittest.SkipTest("sym_max does not work with dynamic dimension") with torch_export_patches(patch_transformers=True, verbose=10): torch.export.export( model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index a7553b0c..a650feb3 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -818,3 +818,152 @@ def forward(self, x, position_ids): sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class patched_IdeficsEmbedding(torch.nn.Module): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # if seq_len > self.max_seq_len_cached: + # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached): + t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(x.dtype), emb.sin().to(x.dtype) + + def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached): + torch._check(seq_len.item() <= cos_cached.shape[0]) + co = cos_cached[: seq_len.item()].detach().clone() + torch._check(seq_len.item() <= sin_cached.shape[0]) + si = sin_cached[: seq_len.item()].detach().clone() + return co.to(dtype=x.dtype), si.to(dtype=x.dtype) + + cos_cached, sin_cached = torch.cond( + (seq_len > self.max_seq_len_cached).item(), + _set_cos_sin_cache_then, + _set_cos_sin_cache_else, + [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached], + ) + return cos_cached, sin_cached + + +class patched_IdeficsAttention(torch.nn.Module): + _PATCHES_ = ["forward"] + _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # if key_value_states are provided this layer is used as a cross-attention layer + is_cross_attention = self.is_cross_attention or key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + if not is_cross_attention: + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + else: + _, kv_len, _ = ( + key_value_states.size() + ) # Note that, in this case, `kv_len` == `kv_seq_len` + key_states = ( + self.k_proj(key_value_states) + .view(bsz, kv_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(key_value_states) + .view(bsz, kv_len, 
self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + + if not is_cross_attention: + rotary_length = torch.maximum( + torch.tensor(kv_seq_len, dtype=torch.int64), + torch.tensor(q_len, dtype=torch.int64), + ) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_length) + query_states, key_states = ( + transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # sin and cos are specific to RoPE models; + # cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if self.qk_layer_norms: + query_states = self.q_layer_norm(query_states) + key_states = self.k_layer_norm(key_states) + + attention_interface: Callable = ( + transformers.models.idefics.modeling_idefics.eager_attention_forward + ) + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + transformers.models.idefics.modeling_idefics.logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True`. Falling back to " + "eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value diff --git a/onnx_diagnostic/torch_models/hghub/hub_api.py b/onnx_diagnostic/torch_models/hghub/hub_api.py index e071b1b0..8b58b4ed 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_api.py +++ b/onnx_diagnostic/torch_models/hghub/hub_api.py @@ -59,9 +59,11 @@ def get_cached_configuration( conf = copy.deepcopy(conf) update_config(conf, kwargs) return conf - assert not exc and not os.environ.get( - "NOHTTP", "" - ), f"Unable to find {name!r} in {pprint.pformat(sorted(cached))}" + assert not exc and not os.environ.get("NOHTTP", ""), ( + f"Unable to find {name!r} (exc={exc}, " + f"NOHTTP={os.environ.get('NOHTTP', '')!r}) " + f"in {pprint.pformat(sorted(cached))}" + ) return None diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index 71a1042a..03453fbf 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -4160,63 +4160,3 @@ def _ccached_ydshieh_tiny_random_vit_for_image_classification(): "transformers_version": "4.24.0.dev0", } ) - - -def _ccached_huggingface_tiny_random_idefics(): - "HuggingFaceM4/tiny-random-idefics" - return transformers.Phi3Config( - **{ - "additional_vocab_size": 2, - "alpha_initializer": "ones", - "alpha_type": "vector", - "alphas_initializer_range": 0.0, - "architectures": ["IdeficsForVisionText2Text"], - "bos_token_id": 1, - 
"cross_layer_activation_function": "swiglu", - "cross_layer_interval": 1, - "dropout": 0.0, - "eos_token_id": 2, - "ffn_dim": 64, - "freeze_lm_head": false, - "freeze_text_layers": false, - "freeze_text_module_exceptions": [], - "freeze_vision_layers": false, - "freeze_vision_module_exceptions": [], - "hidden_act": "silu", - "hidden_size": 16, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_new_tokens": 128, - "max_position_embeddings": 128, - "model_type": "idefics", - "num_attention_heads": 4, - "num_hidden_layers": 2, - "pad_token_id": 0, - "qk_layer_norms": false, - "rms_norm_eps": 1e-06, - "tie_word_embeddings": false, - "torch_dtype": "float16", - "transformers_version": "4.27.0.dev0", - "use_cache": true, - "use_resampler": true, - "vocab_size": 32000, - "word_embed_proj_dim": 16, - "vision_config": { - "hidden_act": "gelu", - "embed_dim": 32, - "image_size": 30, - "intermediate_size": 37, - "patch_size": 2, - "num_attention_heads": 4, - "num_hidden_layers": 5, - "vision_model_name": "hf-internal-testing/tiny-random-clip", - }, - "perceiver_config": { - "qk_layer_norms_perceiver": false, - "resampler_depth": 2, - "resampler_head_dim": 8, - "resampler_n_heads": 2, - "resampler_n_latents": 16, - }, - } - ) diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 8b8b42da..bff1ef75 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -26,7 +26,7 @@ def get_untrained_model_with_inputs( use_preinstalled: bool = True, add_second_input: bool = False, subfolder: Optional[str] = None, - use_only_preinstalled: bool = True, + use_only_preinstalled: bool = False, ) -> Dict[str, Any]: """ Gets a non initialized model similar to the original model From 7ae4d58589fe18480b76a5e904d9738f581185c0 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 14 Jun 2025 15:02:53 +0200 Subject: [PATCH 14/17] fix --- _unittests/ut_xrun_doc/test_documentation_examples.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 3cfd62a4..bbf768f5 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -90,10 +90,17 @@ def add_test_methods(cls): ): reason = "transformers<4.51" + if ( + not reason + and name in {"plot_export_hub_codellama.py"} + and not has_transformers("4.52") + ): + reason = "transformers<4.52" + if ( not reason and name in {"plot_export_locate_issue.py", "plot_export_with_auto.py"} - and not has_torch("4.7") + and not has_torch("2.7") ): reason = "torch<2.7" From 1eb4ee8bd1883880b135e80e1e6e09550dc26dd4 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 14 Jun 2025 15:10:40 +0200 Subject: [PATCH 15/17] fix --- _unittests/ut_xrun_doc/test_documentation_examples.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index bbf768f5..b9505e51 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -100,9 +100,9 @@ def add_test_methods(cls): if ( not reason and name in {"plot_export_locate_issue.py", "plot_export_with_auto.py"} - and not has_torch("2.7") + and not has_torch("2.8") ): - reason = "torch<2.7" + reason = "torch<2.8" if reason: From 
From 9d9ed62e6cc2d4f0870d317bfbfd4408be4f06e4 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Sat, 14 Jun 2025 15:18:57 +0200
Subject: [PATCH 16/17] atol

---
 _unittests/ut_helpers/test_doc_helper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/_unittests/ut_helpers/test_doc_helper.py b/_unittests/ut_helpers/test_doc_helper.py
index 76ef61bc..7d99f599 100644
--- a/_unittests/ut_helpers/test_doc_helper.py
+++ b/_unittests/ut_helpers/test_doc_helper.py
@@ -56,7 +56,7 @@ def test_custom_doc_kernels_layer_normalization(self):
         )
         expected = torch_sess.run(None, feeds)
         got = torch_sess_custom.run(None, feeds)
-        self.assertEqualAny(expected, got, atol=1e-3)
+        self.assertEqualAny(expected, got, atol=2e-3)

     def test_custom_doc_kernels_matmul(self):

From 13a9b8a1098f55d2283d769c9b109ec8c25e0f11 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Sat, 14 Jun 2025 15:28:50 +0200
Subject: [PATCH 17/17] fix

---
 _unittests/ut_xrun_doc/test_documentation_examples.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py
index b9505e51..3179969a 100644
--- a/_unittests/ut_xrun_doc/test_documentation_examples.py
+++ b/_unittests/ut_xrun_doc/test_documentation_examples.py
@@ -99,7 +99,12 @@ def add_test_methods(cls):

         if (
             not reason
-            and name in {"plot_export_locate_issue.py", "plot_export_with_auto.py"}
+            and name
+            in {
+                "plot_export_locate_issue.py",
+                "plot_export_with_auto.py",
+                "plot_export_hub_codellama.py",
+            }
             and not has_torch("2.8")
         ):
             reason = "torch<2.8"
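The atol relaxation in [PATCH 16/17] widens the accepted gap between the reference and custom layer-normalization kernels from 1e-3 to 2e-3, which suggests the observed float32 discrepancy slightly exceeded the old bound. At its core such a check is an elementwise absolute-difference test, as in this sketch (`assertEqualAny` itself may do more):

    import torch

    def assert_close_atol(expected, got, atol):
        # weakest form of the comparison: maximum absolute difference
        for e, g in zip(expected, got):
            diff = (torch.as_tensor(e) - torch.as_tensor(g)).abs().max().item()
            assert diff <= atol, f"max abs diff {diff} > atol {atol}"

    x = torch.randn(4, 8)
    ref = torch.nn.functional.layer_norm(x, (8,))
    got = ref + 1.5e-3  # inside the new 2e-3 bound, outside the old 1e-3 one
    assert_close_atol([ref], [got], atol=2e-3)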