From afebca5e3e9816803913738c4270be4d3012ccbb Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 17 Oct 2025 15:55:20 +0200
Subject: [PATCH 1/8] allow converting a model with inputs from another task

---
 onnx_diagnostic/tasks/text_generation.py |  6 +++++
 .../torch_models/hghub/model_inputs.py    |  8 ++++---
 onnx_diagnostic/torch_models/validate.py  | 24 ++++++++++++-------
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py
index db930b08..9643dc34 100644
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -19,6 +19,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     check_hasattr(
         config,
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
@@ -308,6 +311,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

     If the configuration is None, the function selects typical dimensions.
     """
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     if config is not None:
         check_hasattr(
             config,
diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py
index 34e6bee2..47c1a50a 100644
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -120,14 +120,16 @@ def get_untrained_model_with_inputs(
         **(model_kwargs or {}),
     )

-    model, task, mkwargs, diff_config = None, None, {}, None
+    model, task_, mkwargs, diff_config = None, None, {}, None
     if use_pretrained and same_as_pretrained:
         if model_id in HANDLED_MODELS:
-            model, task, config = load_specific_model(model_id, verbose=verbose)
+            model, task_, config = load_specific_model(model_id, verbose=verbose)
+            if task is None:
+                task = task_

     if model is None:
         arch = architecture_from_config(config)
-        if arch is None:
+        if task is None and arch is None:
             task = task_from_id(model_id, subfolder=subfolder)
         assert task is not None or arch is not None, (
             f"Unable to determine the architecture for model {model_id!r}, "
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index dda71c9b..e5cebfea 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -117,11 +117,21 @@ def _make_folder_name(
     drop_inputs: Optional[List[str]] = None,
     same_as_pretrained: bool = False,
     use_pretrained: bool = False,
+    task: Optional[str] = None,
 ) -> str:
     "Creates a filename unique based on the given options."
     els = [model_id.replace("/", "_")]
     if subfolder:
         els.append(subfolder.replace("/", "_"))
+    if task:
+        els.append(task)
+    if drop_inputs:
+        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
+        els.append(f"I-{ii.upper()}")
+    if use_pretrained:
+        els.append("TRAINED")
+    elif same_as_pretrained:
+        els.append("SAMESIZE")
     if exporter:
         els.append(exporter)
     if optimization:
@@ -142,14 +152,7 @@ def _make_folder_name(
         els.append(sdev)
     if opset is not None:
         els.append(f"op{opset}")
-    if drop_inputs:
-        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
-        els.append(f"I-{ii.upper()}")
-    if use_pretrained:
-        els.append("TRAINED")
-    elif same_as_pretrained:
-        els.append("SAMESIZE")
-    return "-".join(els)
+    return "/".join(els)


 def version_summary() -> Dict[str, Union[int, float, str]]:
@@ -476,6 +479,7 @@ def validate_model(
         drop_inputs=drop_inputs,
         use_pretrained=use_pretrained,
         same_as_pretrained=same_as_pretrained,
+        task=task,
     )
     dump_folder = os.path.join(dump_folder, folder_name)
     if not os.path.exists(dump_folder):
@@ -490,6 +494,8 @@ def validate_model(
             print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
         else:
             print(f"[validate_model] validate model id {model_id!r}")
+        if task:
+            print(f"[validate_model] with task {task!r}")
         print(f"[validate_model] patch={patch!r}")
         if model_options:
             print(f"[validate_model] model_options={model_options!r}")
@@ -765,6 +771,7 @@ def validate_model(
             ep = data["exported_program"]
             if verbose:
                 print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
+            folder_name = folder_name.replace("/", "-")
             with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
                 f.write(str(ep))
             torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
@@ -773,6 +780,7 @@ def validate_model(
             if verbose:
                 print("[validate_model] done (dump ep)")
         if "onnx_program" in data:
+            folder_name = folder_name.replace("/", "-")
             epo = data["onnx_program"]
             if verbose:
                 print(f"[validate_model] dumps onnx program in {dump_folder!r}...")

From 2f132c117b01d421c467dc1c4544b4df5ed947ec Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 17 Oct 2025 15:58:06 +0200
Subject: [PATCH 2/8] doc

---
 CHANGELOGS.rst | 1 +
 _doc/conf.py   | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 6f9f39c9..7a90b688 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.7.15
 ++++++

+* :pr:`264`: allows validating a model with inputs defined for another task
 * :pr:`261`: updates to support ``transformers>=5.0``

 0.7.14
diff --git a/_doc/conf.py b/_doc/conf.py
index 86c38d63..e30afe02 100644
--- a/_doc/conf.py
+++ b/_doc/conf.py
@@ -229,7 +229,7 @@ def linkcode_resolve(domain, info):
     "Linux": "https://www.linux.org/",
     "ml_dtypes": "https://github.com/jax-ml/ml_dtypes",
     "ModelBuilder": "https://onnxruntime.ai/docs/genai/howto/build-model.html",
-    "monai": "https://monai.io/",
+    "monai": "https://github.com/Project-MONAI/MONAI",
     "numpy": "https://numpy.org/",
     "onnx": "https://onnx.ai/onnx/",
     "onnx-ir": "https://github.com/onnx/ir-py",

From 805107a86e1648d7f8b7b482aef4eaefa4b1ef81 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 17 Oct 2025 16:01:43 +0200
Subject: [PATCH 3/8] pym

---
 onnx_diagnostic/torch_models/validate.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index e5cebfea..f85eda12 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -124,7 +124,7 @@ def _make_folder_name(
     if subfolder:
         els.append(subfolder.replace("/", "_"))
     if task:
-        els.append(task)
+        els.append(task)  # type: ignore[arg-type]
     if drop_inputs:
         ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
         els.append(f"I-{ii.upper()}")
@@ -771,6 +771,9 @@ def validate_model(
             ep = data["exported_program"]
             if verbose:
                 print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
+            assert isinstance(
+                folder_name, str
+            ), f"folder_name={folder_name!r} should be a string"
             folder_name = folder_name.replace("/", "-")
             with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
                 f.write(str(ep))
@@ -780,6 +783,9 @@ def validate_model(
             if verbose:
                 print("[validate_model] done (dump ep)")
         if "onnx_program" in data:
+            assert isinstance(
+                folder_name, str
+            ), f"folder_name={folder_name!r} should be a string"
             folder_name = folder_name.replace("/", "-")
             epo = data["onnx_program"]
             if verbose:

From c2916aab6f753e7355bf8e970a80f7b3bd0192fa Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 17 Oct 2025 16:20:44 +0200
Subject: [PATCH 4/8] fix

---
 onnx_diagnostic/torch_models/hghub/model_inputs.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py
index 47c1a50a..cd12b4b5 100644
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -95,6 +95,8 @@ def get_untrained_model_with_inputs(
         print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
         print("-- configuration:", pprint.pformat(data['configuration']))
     """
+    if task == "":
+        task = None
     assert not use_preinstalled or not use_only_preinstalled, (
         f"model_id={model_id!r}, preinstalled model is only available "
         f"if use_only_preinstalled is False."
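Taken together, patches 1-4 make the task an explicit, overridable input of the validation pipeline: validate_model forwards it to _make_folder_name and to get_untrained_model_with_inputs, and the text-generation helpers unwrap config.text_config when the configuration nests one. A minimal sketch of the resulting call, assuming only the validate_model signature visible in the hunks above; the model id, exporter value, and dump folder below are illustrative, not taken from this PR:

    from onnx_diagnostic.torch_models.validate import validate_model

    # Validate a multimodal checkpoint with inputs generated for the
    # text-generation task; the nested text_config is picked up automatically.
    summary, data = validate_model(
        "some-org/tiny-vlm",         # hypothetical model id
        task="text-generation",      # overrides the task inferred from the model id
        exporter="export-nostrict",  # assumed exporter name
        dump_folder="dump_test",     # assumed output folder
        verbose=1,
    )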
From 29feec4a2c147613b50a4570f6d37f99588d19d7 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 17 Oct 2025 16:37:54 +0200
Subject: [PATCH 5/8] none

---
 _unittests/ut_torch_models/test_validate_whole_models1.py | 2 +-
 onnx_diagnostic/torch_models/validate.py                  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/_unittests/ut_torch_models/test_validate_whole_models1.py b/_unittests/ut_torch_models/test_validate_whole_models1.py
index 7d8f169d..5fe52e61 100644
--- a/_unittests/ut_torch_models/test_validate_whole_models1.py
+++ b/_unittests/ut_torch_models/test_validate_whole_models1.py
@@ -98,7 +98,7 @@ def test_f_validate_model_onnx_dynamo_ir(self):
     @requires_torch("2.7")
     @requires_onnxscript("0.7")
     @hide_stdout()
-    @ignore_warnings(FutureWarning)
+    @ignore_warnings((FutureWarning, RuntimeWarning))
     def test_g_validate_model_onnx_dynamo_os_ort(self):
         mid = "arnir0/Tiny-LLM"
         summary, data = validate_model(
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index f85eda12..b10aa577 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -152,7 +152,7 @@ def _make_folder_name(
         els.append(sdev)
     if opset is not None:
         els.append(f"op{opset}")
-    return "/".join(els)
+    return "/".join([e for e in els if e])

From fbbf946833b6e662995ceeefdbe761590bdc960d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 17 Oct 2025 17:14:40 +0200
Subject: [PATCH 6/8] fix rope

---
 .../patches/patch_transformers.py | 33 ++++++++++++-------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index 0e1ae85b..6cc4651d 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1019,6 +1019,25 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor


+def _get_rope_init_fn(self) -> Callable:
+    if hasattr(self, "rope_init_fn"):
+        # transformers<5.0
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+        return rope_init_fn
+
+    rope_init_fn = self.compute_default_rope_parameters
+    if self.rope_type != "default":
+        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
+    if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
+        return patched__compute_dynamic_ntk_parameters
+    return rope_init_fn
+
+
 def patched_dynamic_rope_update(rope_forward):
     """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

@@ -1087,12 +1106,7 @@ def longrope_frequency_update(self, position_ids, device):
         # It is no use to patch the function after the model is created
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self)
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
@@ -1128,12 +1142,7 @@ def dynamic_frequency_update(self, position_ids, device):
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self)

         # This behaviour is difficult to translate.
         # The sequence always grows.

From deb39ecec9d8ba22da448f0e8ba96e9b971c83c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 17 Oct 2025 17:35:02 +0200
Subject: [PATCH 7/8] fix optional

---
 onnx_diagnostic/helpers/config_helper.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py
index 02a025b2..262e11fc 100644
--- a/onnx_diagnostic/helpers/config_helper.py
+++ b/onnx_diagnostic/helpers/config_helper.py
@@ -95,7 +95,8 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
     mod_name = cls.__module__
     mod = importlib.import_module(mod_name)
     source = inspect.getsource(mod)
-    reg = re.compile("config: ([A-Za-z0-9]+)")
+    # [^O] avoids capturing Optional[Something]
+    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
     fall = reg.findall(source)
     if len(fall) == 0:
         assert not exc, (

From d19c66b68a9a807ba0c0cd65dfe492d3d83b8a55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 17 Oct 2025 18:16:34 +0200
Subject: [PATCH 8/8] fix rotary patch

---
 .../patches/patch_transformers.py | 79 ++++++++++++++-----
 1 file changed, 60 insertions(+), 19 deletions(-)

diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index 6cc4651d..ac1b7270 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1019,7 +1019,7 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor


-def _get_rope_init_fn(self) -> Callable:
+def _get_rope_init_fn(self, layer_type=None) -> Callable:
     if hasattr(self, "rope_init_fn"):
         # transformers<5.0
         rope_init_fn = (
@@ -1030,8 +1030,9 @@ def _get_rope_init_fn(self) -> Callable:
         )
         return rope_init_fn

+    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
     rope_init_fn = self.compute_default_rope_parameters
-    if self.rope_type != "default":
-        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
+    if rope_type != "default":
+        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[rope_type]
     if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
         return patched__compute_dynamic_ntk_parameters
     return rope_init_fn
@@ -1101,17 +1102,27 @@ def wrapper(self, x, position_ids):

     """

-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
         # It is no use to patch the function after the model is created
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = _get_rope_init_fn(self)
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # At export time, seq_len is unknown.
         long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
@@ -1126,13 +1137,13 @@ def longrope_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
         #     self.inv_freq = self.long_inv_freq
         # else:
         #     self.inv_freq = self.original_inv_freq

-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
         # constructor:
         # - self.max_seq_len_cached = config.max_position_embeddings
         # - self.original_max_seq_len = config.max_position_embeddings
@@ -1142,7 +1153,7 @@ def dynamic_frequency_update(self, position_ids, device):
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = _get_rope_init_fn(self)
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)

         # This behaviour is difficult to translate.
         # The sequence always grows.
@@ -1171,6 +1182,19 @@ def dynamic_frequency_update(self, position_ids, device):
             self.config, device, seq_len=seq_len
         )

+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # Second test to translate.
         # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
         # But in that case the following condition is a way to restore the original cache.
@@ -1192,15 +1216,26 @@ def dynamic_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)

     @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device)
-        elif self.rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device)
-        return rope_forward(self, x, position_ids)
+    def wrapper(self, x, position_ids, layer_type=None):
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
+        if "dynamic" in self.rope_type[layer_type]:
+            dynamic_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        elif self.rope_type[layer_type] == "longrope":
+            longrope_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        return rope_forward(self, x, position_ids, layer_type=layer_type)

     return wrapper
@@ -1296,12 +1331,18 @@ class common_RotaryEmbedding(torch.nn.Module):

     # @torch.no_grad()  # PATCHED: the decorator
     @patched_dynamic_rope_update
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, layer_type=None):
+        if layer_type is not None:
+            # transformers>=5.0
+            inv_freq = getattr(self, f"{layer_type}_inv_freq")
+            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+        else:
+            # transformers<5.0
+            inv_freq = self.inv_freq
+            attention_scaling = self.attention_scaling
+
         inv_freq_expanded = (
-            self.inv_freq[None, :, None]
-            .float()
-            .expand(position_ids.shape[0], -1, 1)
-            .to(x.device)
+            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         )
         position_ids_expanded = position_ids[:, None, :].float()
@@ -1313,8 +1354,8 @@ def forward(self, x, position_ids):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
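All of patch 8 rests on a single naming convention: when transformers passes a layer_type, every per-layer buffer is resolved under a "{layer_type}_" prefix, otherwise the bare legacy attribute is used. A standalone sketch of that convention, assuming transformers>=5.0 stores one inv_freq per layer type under such prefixed names; the layer-type names and tensor sizes below are invented for the example:

    import torch

    class TinyRope(torch.nn.Module):
        # Mirrors the getattr/setattr pattern used by the patched rotary
        # embedding: per-layer-type data under a "{layer_type}_" prefix,
        # legacy (transformers<5.0) data under the bare attribute name.
        def __init__(self):
            super().__init__()
            self.rope_type = {"full_attention": "default", "sliding_attention": "dynamic"}
            self.inv_freq = torch.ones(4)                 # legacy layout
            self.full_attention_inv_freq = torch.ones(4)  # per-layer layout
            self.sliding_attention_inv_freq = torch.ones(8)

        def resolve_inv_freq(self, layer_type=None):
            prefix = "" if layer_type is None else f"{layer_type}_"
            return getattr(self, f"{prefix}inv_freq")

    rope = TinyRope()
    assert rope.resolve_inv_freq().shape == (4,)
    assert rope.resolve_inv_freq("sliding_attention").shape == (8,)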