From cd01099a529951155d5c6ea83806687c9950340f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 14:58:17 +0100 Subject: [PATCH 1/7] add attention 24 --- _unittests/ut_tasks/try_export.py | 23 ++- .../test_patch_transformers.py | 195 ++++++++++++------ .../patches/_patch_transformers_qwen2_5.py | 89 ++++++-- 3 files changed, 210 insertions(+), 97 deletions(-) diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index e39a407e..47cb3788 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -148,7 +148,7 @@ def _config_reduction(config, task): elif device == "cuda" and dtype in ("float16", "bfloat16"): attention_options = ["PACKED", "BIGMASK"] else: - attention_options = ["LOOPMHA", "BIGMASK"] + attention_options = ["LOOPMHA", "LOOPA24", "BIGMASK"] # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0] for attention in attention_options: @@ -180,7 +180,7 @@ def _config_reduction(config, task): exporter=exporter, verbose=1, save_ep=None if self.unit_test_going() else (fileep, 2**35), - target_opset=22, + target_opset=24 if attention == "LOOPMHA" else 22, optimize=True, onnx_plugs=PLUGS, ) @@ -207,17 +207,18 @@ def _config_reduction(config, task): print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}") model = onnx.load(filename, load_external_data=False) if attention == "PACKED": - self.assertIn( - "PackedMultiHeadAttention", {n.op_type for n in model.graph.node} - ) + self.assertIn("PackedMultiHeadAttention", str(model)) elif attention == "BIGMASK": - self.assertNotIn( - "PackedMultiHeadAttention", {n.op_type for n in model.graph.node} - ) + self.assertNotIn("PackedMultiHeadAttention", str(model)) + self.assertNotIn("MultiHeadAttention", str(model)) + self.assertNotIn("Loop", {n.op_type for n in model.graph.node}) elif attention == "LOOPMHA": - self.assertNotIn( - "PackedMultiHeadAttention", {n.op_type for n in model.graph.node} - ) + self.assertNotIn("PackedMultiHeadAttention", str(model)) + self.assertIn("MultiHeadAttention", str(model)) + self.assertIn("Loop", {n.op_type for n in model.graph.node}) + elif attention == "LOOPA24": + self.assertNotIn("PackedMultiHeadAttention", str(model)) + self.assertNotIn("MultiHeadAttention", str(model)) self.assertIn("Loop", {n.op_type for n in model.graph.node}) else: raise AssertionError(f"attention={attention!r} not expected") diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 0ff07656..c8fa3add 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -519,9 +519,43 @@ def test_qwen2_5_vl_vision_attention_iteration(self): ) self.clean_dump() + @classmethod + def _get_seqlen(cls) -> torch.Tensor: + return torch.tensor( + [ + 0, + 64, + 128, + 192, + 256, + 304, + 368, + 432, + 496, + 560, + 608, + 672, + 736, + 800, + 864, + 912, + 976, + 1040, + 1104, + 1168, + 1216, + 1232, + 1248, + 1264, + 1280, + 1292, + ], + dtype=torch.int64, + ) + @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") @requires_cuda() - def test_plug_packed_multi_head_attention_qwen25_packed(self): + def test_plug_packed_multi_head_attention_qwen25_packed_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( qwen_sdpa_attention_packed_versatile, ) @@ -530,37 +564,7 @@ def 
test_plug_packed_multi_head_attention_qwen25_packed(self): torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), - torch.tensor( - [ - 0, - 64, - 128, - 192, - 256, - 304, - 368, - 432, - 496, - 560, - 608, - 672, - 736, - 800, - 864, - 912, - 976, - 1040, - 1104, - 1168, - 1216, - 1232, - 1248, - 1264, - 1280, - 1292, - ], - dtype=torch.int64, - ).to("cuda"), + self._get_seqlen().to("cuda"), ) results = qwen_sdpa_attention_packed_versatile.verify( @@ -580,7 +584,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self): self.assertLess(results.diffs[0]["abs"], 0.01) @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") - def test_plug_packed_multi_head_attention_qwen25_loopmha(self): + def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( qwen_sdpa_attention_loopmha_versatile, ) @@ -589,46 +593,15 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self): torch.rand((1, 16, 1292, 80), dtype=torch.float16), torch.rand((1, 16, 1292, 80), dtype=torch.float16), torch.rand((1, 16, 1292, 80), dtype=torch.float16), - torch.tensor( - [ - 0, - 64, - 128, - 192, - 256, - 304, - 368, - 432, - 496, - 560, - 608, - 672, - 736, - 800, - 864, - 912, - 976, - 1040, - 1104, - 1168, - 1216, - 1232, - 1248, - 1264, - 1280, - 1292, - ], - dtype=torch.int64, - ), + self._get_seqlen(), ) results = qwen_sdpa_attention_loopmha_versatile.verify( *inputs, scaling=0.5, num_heads=16, - itype=onnx.TensorProto.FLOAT16, dump_onnx_model=self.get_dump_file( - "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx" + "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx" ), ) self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) @@ -637,13 +610,101 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self): self.assertLess(results.diffs[0]["abs"], 0.01) results = qwen_sdpa_attention_loopmha_versatile.verify( - *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16 + *inputs, scaling=0.11180339887498948, num_heads=16 ) self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) self.assertEqual(len(results.eager_outputs), len(results.diffs)) self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) self.assertLess(results.diffs[0]["abs"], 0.01) + @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") + def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self): + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( + qwen_sdpa_attention_loopmha_versatile, + ) + + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + self._get_seqlen(), + ) + + results = qwen_sdpa_attention_loopmha_versatile.verify( + *inputs, + scaling=0.5, + num_heads=16, + dump_onnx_model=self.get_dump_file( + "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx" + ), + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) + + results = 
qwen_sdpa_attention_loopmha_versatile.verify( + *inputs, scaling=0.11180339887498948, num_heads=16 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) + + @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") + def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self): + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( + qwen_sdpa_attention_loopa24_versatile, + ) + + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + self._get_seqlen(), + ) + + results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) + + results = qwen_sdpa_attention_loopa24_versatile.verify( + *inputs, scaling=0.11180339887498948 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) + + @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") + def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self): + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( + qwen_sdpa_attention_loopa24_versatile, + ) + + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + self._get_seqlen(), + ) + + results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) + + results = qwen_sdpa_attention_loopa24_versatile.verify( + *inputs, scaling=0.11180339887498948 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index 6a024470..20acf68a 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -22,6 +22,7 @@ onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1) op = onnxscript.opset22 + op24 = onnxscript.onnx_opset.opset24 msft_op = onnxscript.values.Opset("com.microsoft", 1) @onnxscript.script(opset=onnx_plugs_op) @@ -32,7 +33,6 @@ def LoopMHAAttention( cu_seqlens, scaling: float = 
0.11180339887498948, num_heads: int = 16, - itype: int = onnx.TensorProto.FLOAT, ): to_3d_shape = op.Constant(value_ints=[0, 0, -1]) query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3]) @@ -45,7 +45,7 @@ def LoopMHAAttention( seq_axis = op.Constant(value_ints=[1]) seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32) # attn_output = op.Slice(value_3d, [0], [0], seq_axis) - seq_attn = op.SequenceEmpty(dtype=itype) + seq_attn = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT) for i_patch in range(num_patches): i_1d = op.Reshape(i_patch, [1]) i_plus_1_1d = i_1d + 1 @@ -55,11 +55,7 @@ def LoopMHAAttention( key_i = op.Slice(key_3d, start, end, seq_axis_int32) value_i = op.Slice(value_3d, start, end, seq_axis_int32) mha_output = msft_op.MultiHeadAttention( - query_i, - key_i, - value_i, - num_heads=num_heads, - scale=scaling, + query_i, key_i, value_i, num_heads=num_heads, scale=scaling ) # attn_output = op.Concat(attn_output, mha_output, axis=1) seq_attn = op.SequenceInsert(seq_attn, mha_output) @@ -67,6 +63,47 @@ def LoopMHAAttention( attn_output_4d = op.Reshape(attn_output, output_shape) return attn_output_4d + @onnxscript.script(opset=onnx_plugs_op) + def LoopAttention24( + query_states, + key_states, + value_states, + cu_seqlens, + scaling: float = 0.11180339887498948, + num_heads: int = 16, + ): + to_3d_shape = op24.Constant(value_ints=[0, 0, -1]) + query_transposed = op24.Transpose(query_states, perm=[0, 2, 1, 3]) + output_shape = op24.Shape(query_transposed) + query_3d = op24.Reshape(query_transposed, to_3d_shape) + value_3d = op24.Reshape(op24.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape) + key_3d = op24.Reshape(op24.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape) + cu_seqlens = op24.Cast(cu_seqlens, to=onnx.TensorProto.INT32) + num_patches = op24.Size(cu_seqlens) - 1 + seq_axis = op24.Constant(value_ints=[1]) + seq_axis_int32 = op24.Cast(seq_axis, to=onnx.TensorProto.INT32) + seq_attn = op24.SequenceEmpty(dtype=onnx.TensorProto.FLOAT) + for i_patch in range(num_patches): + i_1d = op24.Reshape(i_patch, [1]) + i_plus_1_1d = i_1d + 1 + start = op24.Gather(cu_seqlens, i_1d, axis=0) + end = op24.Gather(cu_seqlens, i_plus_1_1d, axis=0) + query_i = op24.Slice(query_3d, start, end, seq_axis_int32) + key_i = op24.Slice(key_3d, start, end, seq_axis_int32) + value_i = op24.Slice(value_3d, start, end, seq_axis_int32) + mha_output = op24.Attention( + query_i, + key_i, + value_i, + scale=scaling, + q_num_heads=num_heads, + kv_num_heads=num_heads, + ) + seq_attn = op24.SequenceInsert(seq_attn, mha_output) + attn_output = op24.ConcatFromSequence(seq_attn, axis=1) + attn_output_4d = op24.Reshape(attn_output, output_shape) + return attn_output_4d + def _add_com_microsoft_opset(function_proto): opsets = {d.domain: d.version for d in function_proto.opset_import} if "com.microsoft" not in opsets: @@ -132,7 +169,6 @@ def qwen_sdpa_attention( cu_seqlens: torch.Tensor, # F7su19 scaling: float = 0, num_heads: int = 16, - itype: int = onnx.TensorProto.FLOAT, ) -> torch.Tensor: lengths = cu_seqlens[1:] - cu_seqlens[:-1] splits = [ @@ -167,7 +203,7 @@ def qwen_sdpa_attention( _add_com_microsoft_opset(PackedAttention.to_function_proto()), n_inputs=4, n_outputs=1, - kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT), + kwargs=dict(scaling=0.11180339887498948, num_heads=16), name="qwen_sdpa_attention_packed", ) PLUGS.append(qwen_sdpa_attention_packed_versatile) @@ -182,11 +218,26 @@ def qwen_sdpa_attention( 
_add_com_microsoft_opset(LoopMHAAttention.to_function_proto()), n_inputs=4, n_outputs=1, - kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT), + kwargs=dict(scaling=0.11180339887498948, num_heads=16), name="qwen_sdpa_attention_loopmha", ) PLUGS.append(qwen_sdpa_attention_loopmha_versatile) + qwen_sdpa_attention_loopa24_versatile = EagerDirectReplacementWithOnnx( + qwen_sdpa_attention, + lambda qs, *args, **kwargs: torch.empty( + (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]), + dtype=qs.dtype, + device=qs.device, + ), + LoopAttention24.to_function_proto(), + n_inputs=4, + n_outputs=1, + kwargs=dict(scaling=0.11180339887498948, num_heads=16), + name="qwen_sdpa_attention_loopa24", + ) + PLUGS.append(qwen_sdpa_attention_loopa24_versatile) + class patched_Qwen2_5_VLForConditionalGeneration: _PATCHES_ = ["prepare_inputs_for_generation"] _PATCHED_CLASS_ = ( @@ -558,6 +609,15 @@ def forward( ), version=1, ) + elif is_sdpa and attention_strategy == "LOOPA24": + attn_output = qwen_sdpa_attention_loopa24_versatile( + query_states, + key_states, + value_states, + cu_seqlens, + self.scaling, + self.num_heads, + ) elif is_sdpa and attention_strategy == "LOOPMHA": attn_output = qwen_sdpa_attention_loopmha_versatile( query_states, @@ -566,15 +626,6 @@ def forward( cu_seqlens, self.scaling, self.num_heads, - ( - onnx.TensorProto.FLOAT - if query_states.dtype == torch.float32 - else ( - onnx.TensorProto.FLOAT16 - if query_states.dtype == torch.float16 - else onnx.TensorProto.BFLOAT16 - ) - ), ) # to rewrite later with a for loop From 51b8db5883dc3a76b0b26924cbda3c85b095ed6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 16:32:51 +0100 Subject: [PATCH 2/7] implements version plugs --- _unittests/ut_tasks/try_export.py | 7 +- .../test_patch_transformers.py | 8 +- onnx_diagnostic/export/onnx_plug.py | 178 ++++++++++++------ .../patches/_patch_transformers_qwen2_5.py | 60 +++++- 4 files changed, 179 insertions(+), 74 deletions(-) diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index 47cb3788..36cc3c7f 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -17,6 +17,11 @@ class TestTryExportHuggingFaceHubModel(ExtTestCase): @ignore_warnings(UserWarning) def test_qwen25_vli_visual(self): """ + unittest:: + + UNITTEST_GOING=1 NEVERTEST=1 TESTDTYPE=float16 TESTDEVICE=cpu python \\ + _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual + # task: imagetext2text clear&&NEVERTEST=1 python _unittests/ut_tasks/try_export.py -k qwen_2_5 @@ -180,7 +185,7 @@ def _config_reduction(config, task): exporter=exporter, verbose=1, save_ep=None if self.unit_test_going() else (fileep, 2**35), - target_opset=24 if attention == "LOOPMHA" else 22, + target_opset=24 if attention == "LOOPA24" else 22, optimize=True, onnx_plugs=PLUGS, ) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index c8fa3add..2d85e5f7 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -667,16 +667,16 @@ def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self): results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5) self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) self.assertEqual(len(results.eager_outputs), len(results.diffs)) - 
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) - self.assertLess(results.diffs[0]["abs"], 1e-5) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2) + self.assertLess(results.diffs[0]["abs"], 1e-2) results = qwen_sdpa_attention_loopa24_versatile.verify( *inputs, scaling=0.11180339887498948 ) self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) - self.assertLess(results.diffs[0]["abs"], 1e-5) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005) + self.assertLess(results.diffs[0]["abs"], 0.005) @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self): diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 02c08edf..cd4614f5 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import onnx import torch -from ..helpers import max_diff +from ..helpers import max_diff, string_type from ..helpers.torch_helper import torch_dtype_to_onnx_dtype from ..reference import OnnxruntimeEvaluator @@ -50,6 +50,7 @@ class EagerDirectReplacementWithOnnx: only tensors must be counted :param name: the name of the custom op, the function name if not specified :param kwargs: constants parameters with their default values + :param version_selector: selects the version based on the arguments :param verbose: verbose level Here is an example: @@ -139,21 +140,28 @@ def __init__( self, eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS], shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS], - function_proto: onnx.FunctionProto, + function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]], n_inputs: Optional[int] = None, n_outputs: Optional[int] = None, name: Optional[str] = None, kwargs: Optional[Dict[str, Union[int, float]]] = None, verbose: int = 0, + version_selector: Optional[Callable[[Any], Any]] = None, ): - assert isinstance( - function_proto, onnx.FunctionProto + assert isinstance(function_proto, onnx.FunctionProto) or ( + isinstance(function_proto, dict) + or all(isinstance(v, onnx.FunctionProto) for v in function_proto.values()) ), f"Unexpected type {type(function_proto)} for function_proto" assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}" - assert isinstance(n_outputs, int), f"not implemented yet when n_inputs={n_outputs}" + assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}" self.eager_fn = eager_fn self.shape_fn = shape_fn - self.function_proto = function_proto + self._function_proto = ( + function_proto if isinstance(function_proto, onnx.FunctionProto) else None + ) + self._function_proto_versioned = ( + function_proto if isinstance(function_proto, dict) else {} + ) self.n_inputs = n_inputs self.n_outputs = n_outputs self.name = name or ( @@ -170,24 +178,72 @@ def __init__( ) sig = inspect.signature(self.eager_fn) params = list(sig.parameters) - assert ( - len(params) >= n_inputs - ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}" - assert n_inputs == len(function_proto.input), ( - f"Input mismatch n_inputs={n_inputs} but " - f"function_proto.input={function_proto.input}" - ) - assert 
n_outputs == len(function_proto.output), ( - f"Output mismatch n_outputs={n_outputs} but " - f"function_proto.output={function_proto.output}" - ) - assert ( - function_proto.domain == self.domain - ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}" self.args_name = [p for p in params if p not in self.kwargs] self.kwargs_name = [p for p in params if p in self.kwargs] self.verbose = verbose self.custom_op = self._register() + self.version_selector = version_selector + self._check_protos(params) + + def _check_protos(self, params): + assert ( + len(params) >= self.n_inputs + ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}" + + # one proto + assert self._function_proto is None or self.n_inputs == len( + self._function_proto.input + ), ( + f"Input mismatch n_inputs={self.n_inputs} but " + f"function_proto.input={self._function_proto.input}" + ) + assert self._function_proto is None or self.n_outputs == len( + self._function_proto.output + ), ( + f"Output mismatch n_outputs={self.n_outputs} but " + f"function_proto.output={self._function_proto.output}" + ) + assert self._function_proto is None or ( + self._function_proto.domain == self.domain + ), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}" + + # multiple protos + assert all( + self.n_inputs == len(v.input) for v in self._function_proto_versioned.values() + ), f"Output mismatch n_inputs={self.n_inputs} but one verion is wrong" + assert all( + self.n_outputs == len(v.output) for v in self._function_proto_versioned.values() + ), f"Output mismatch n_outputs={self.n_outputs} but one verion is wrong" + assert all( + v.domain == self.domain for v in self._function_proto_versioned.values() + ), f"Function domain must be {self.domain!r} but it is different in one version" + assert ( + not self._function_proto_versioned or self.version_selector + ), "version_selector is needed when multiple protos are given." 
+ + def get_function_proto(self, *args) -> onnx.FunctionProto: + """Returns the correct version based on the inputs.""" + if self._function_proto: + return self._function_proto + if ( + len(args) == 1 + and isinstance(args[0], (int, str)) + and args[0] in self._function_proto_versioned + ): + return self._function_proto_versioned[args[0]] + try: + key = self.version_selector(*args) + except (ValueError, AttributeError) as e: + raise AssertionError( + f"Unable to select a version, fails to get a key, available=" + f"{set(self._function_proto_versioned)}, " + f"args={string_type(args,with_shape=True)}" + ) from e + assert key in self._function_proto_versioned, ( + f"Unable to select a version, key={key}, available=" + f"{set(self._function_proto_versioned)}, args={string_type(args,with_shape=True)}" + ) + return self._function_proto_versioned[key] @property def domain(self) -> str: @@ -291,7 +347,7 @@ def verify( assert engine is None, f"Not implemented yet with engine={engine!r}" ags, kws = self._make_args_kwargs(*args, **kwargs) sess = OnnxruntimeEvaluator( - self.function_proto, + self.get_function_proto(*args), whole=True, dump_onnx_model=dump_onnx_model, function_kwargs=kws, @@ -324,16 +380,15 @@ def converter( *args, **kwargs, ) -> Any: - if not g.has_local_function( - self.function_proto.name, domain=self.function_proto.domain - ): - g.add_function(self.function_proto) + function_proto = self.get_function_proto(g.get_type(args[0])) + if not g.has_local_function(function_proto.name, domain=function_proto.domain): + g.add_function(function_proto) ags, kws = self._make_args_kwargs(*args, **kwargs) res = g.make_node( - self.function_proto.name, + function_proto.name, ags, outputs, - domain=self.function_proto.domain, + domain=function_proto.domain, name=self.target_name, **kws, ) @@ -356,41 +411,46 @@ def onnx_dynamo_converter(self) -> Callable: """ import onnxscript - onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1) - schema = onnx_plug_op[self.function_proto.name] - if schema is None: - all_types = [ - "tensor(float)", - "tensor(float16)", - "tensor(bfloat16)", - "tensor(double)", - "tensor(int64)", - "tensor(int32)", - ] - type_constraints = [] - for i in range(self.n_inputs): - type_constraints.append((f"T{i}", all_types, "")) - for i in range(self.n_outputs): - type_constraints.append((f"U{i}", all_types, "")) - schema = onnx.defs.OpSchema( - self.function_proto.name, - self.function_proto.domain, - 1, - inputs=[ - onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}") - for i in range(self.n_inputs) - ], - outputs=[ - onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}") - for i in range(self.n_outputs) - ], - type_constraints=type_constraints, - ) - onnx.defs.register_schema(schema) - op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema) + onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1) + + def get_proto(*args): + function_proto = self.get_function_proto() + schema = onnx_plug_op[function_proto.name] + if schema is None: + all_types = [ + "tensor(float)", + "tensor(float16)", + "tensor(bfloat16)", + "tensor(double)", + "tensor(int64)", + "tensor(int32)", + ] + type_constraints = [] + for i in range(self.n_inputs): + type_constraints.append((f"T{i}", all_types, "")) + for i in range(self.n_outputs): + type_constraints.append((f"U{i}", all_types, "")) + schema = onnx.defs.OpSchema( + function_proto.name, + function_proto.domain, + 1, + inputs=[ + onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}") + 
for i in range(self.n_inputs) + ], + outputs=[ + onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}") + for i in range(self.n_outputs) + ], + type_constraints=type_constraints, + ) + onnx.defs.register_schema(schema) + op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema) + return op def converter(*cargs, **ckwargs): ags, kws = self._make_args_kwargs(*cargs, **ckwargs) + op = get_proto(*cargs) return op(*ags, n_outputs=self.n_outputs, **kws) return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index 20acf68a..dfa5698e 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -1,8 +1,10 @@ import os from typing import Callable, Optional import onnx +import onnx.helper as oh import torch import torch.nn.functional as F +from ...helpers.torch_helper import torch_dtype_to_onnx_dtype from ...export.onnx_plug import EagerDirectReplacementWithOnnx from .patch_helper import _is_torchdynamo_exporting from ._patch_transformers_attention import patched_sdpa_attention_forward @@ -25,6 +27,36 @@ op24 = onnxscript.onnx_opset.opset24 msft_op = onnxscript.values.Opset("com.microsoft", 1) + def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto: + opsets = {d.domain: d.version for d in function_proto.opset_import} + if "com.microsoft" not in opsets: + d = function_proto.opset_import.add() + d.domain = "com.microsoft" + d.version = 1 + return function_proto + + def _update_sequence_type( + itype: int, function_proto: onnx.FunctionProto + ) -> onnx.FunctionProto: + proto = oh.make_function( + function_proto.domain, + function_proto.name, + function_proto.input, + function_proto.output, + [ + ( + oh.make_node("SequenceEmpty", node.input, node.output, dtype=itype) + if node.op_type == "SequenceEmpty" + else node + ) + for node in function_proto.node + ], + attributes=function_proto.attribute, + attribute_protos=function_proto.attribute_proto, + opset_imports=function_proto.opset_import, + ) + return proto + @onnxscript.script(opset=onnx_plugs_op) def LoopMHAAttention( query_states, @@ -98,20 +130,13 @@ def LoopAttention24( scale=scaling, q_num_heads=num_heads, kv_num_heads=num_heads, + softmax_precision=onnx.TensorProto.FLOAT, ) seq_attn = op24.SequenceInsert(seq_attn, mha_output) attn_output = op24.ConcatFromSequence(seq_attn, axis=1) attn_output_4d = op24.Reshape(attn_output, output_shape) return attn_output_4d - def _add_com_microsoft_opset(function_proto): - opsets = {d.domain: d.version for d in function_proto.opset_import} - if "com.microsoft" not in opsets: - d = function_proto.opset_import.add() - d.domain = "com.microsoft" - d.version = 1 - return function_proto - @onnxscript.script(opset=onnx_plugs_op) def PackedAttention( query, @@ -215,11 +240,20 @@ def qwen_sdpa_attention( dtype=qs.dtype, device=qs.device, ), - _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()), + { + onnx.TensorProto.FLOAT: _add_com_microsoft_opset( + LoopMHAAttention.to_function_proto() + ), + onnx.TensorProto.FLOAT16: _update_sequence_type( + onnx.TensorProto.FLOAT16, + _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()), + ), + }, n_inputs=4, n_outputs=1, kwargs=dict(scaling=0.11180339887498948, num_heads=16), name="qwen_sdpa_attention_loopmha", + 
version_selector=lambda *args: torch_dtype_to_onnx_dtype(args[0].dtype), ) PLUGS.append(qwen_sdpa_attention_loopmha_versatile) @@ -230,11 +264,17 @@ def qwen_sdpa_attention( dtype=qs.dtype, device=qs.device, ), - LoopAttention24.to_function_proto(), + { + onnx.TensorProto.FLOAT: LoopAttention24.to_function_proto(), + onnx.TensorProto.FLOAT16: _update_sequence_type( + onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto() + ), + }, n_inputs=4, n_outputs=1, kwargs=dict(scaling=0.11180339887498948, num_heads=16), name="qwen_sdpa_attention_loopa24", + version_selector=lambda *args: torch_dtype_to_onnx_dtype(args[0].dtype), ) PLUGS.append(qwen_sdpa_attention_loopa24_versatile) From 2ab6859e0f51af58140e5bffb7b587856912e991 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 16:55:23 +0100 Subject: [PATCH 3/7] fix plugs --- CHANGELOGS.rst | 2 ++ onnx_diagnostic/export/api.py | 8 +++++++- onnx_diagnostic/export/onnx_plug.py | 4 ++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index ddc855de..ee0b1a3e 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.8.4 +++++ +* :pr:`336`: implements versioned onnx plugs + 0.8.3 +++++ diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index b5a651cd..260a5d6b 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -149,6 +149,7 @@ def find_method(self, name: Any): if exporter in ("dynamo", "onnx-dynamo"): import os + from ..helpers import flatten_object import onnxscript.rewriter.ort_fusions as ort_fusions assert ( @@ -180,7 +181,12 @@ def find_method(self, name: Any): import onnx_ir as ir import onnx_ir.passes.common as common_passes - irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs] + irfunctions = [ + ir.from_proto( + plug.get_function_proto(*flatten_object((args, kwargs), drop_keys=True)) + ) + for plug in onnx_plugs + ] for func in irfunctions: epo.model.functions[func.identifier()] = func if inline: diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index cd4614f5..a4a2afb9 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -210,10 +210,10 @@ def _check_protos(self, params): # multiple protos assert all( self.n_inputs == len(v.input) for v in self._function_proto_versioned.values() - ), f"Output mismatch n_inputs={self.n_inputs} but one verion is wrong" + ), f"Output mismatch n_inputs={self.n_inputs} but one version is wrong" assert all( self.n_outputs == len(v.output) for v in self._function_proto_versioned.values() - ), f"Output mismatch n_outputs={self.n_outputs} but one verion is wrong" + ), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong" assert all( v.domain == self.domain for v in self._function_proto_versioned.values() ), f"Function domain must be {self.domain!r} but it is different in one version" From b34d94f602f62ea48f7b8ef15edd8dab4ccbef67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 18:02:33 +0100 Subject: [PATCH 4/7] fix plugs --- _unittests/ut_tasks/try_export.py | 14 +++++++------- _unittests/ut_torch_onnx/test_sbs.py | 16 +--------------- onnx_diagnostic/export/onnx_plug.py | 2 +- 3 files changed, 9 insertions(+), 23 deletions(-) diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index 36cc3c7f..146d4c4b 100644 --- a/_unittests/ut_tasks/try_export.py +++ 
b/_unittests/ut_tasks/try_export.py @@ -212,18 +212,18 @@ def _config_reduction(config, task): print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}") model = onnx.load(filename, load_external_data=False) if attention == "PACKED": - self.assertIn("PackedMultiHeadAttention", str(model)) + self.assertIn('"PackedMultiHeadAttention"', str(model)) elif attention == "BIGMASK": - self.assertNotIn("PackedMultiHeadAttention", str(model)) + self.assertNotIn('"PackedMultiHeadAttention"', str(model)) self.assertNotIn("MultiHeadAttention", str(model)) self.assertNotIn("Loop", {n.op_type for n in model.graph.node}) elif attention == "LOOPMHA": - self.assertNotIn("PackedMultiHeadAttention", str(model)) - self.assertIn("MultiHeadAttention", str(model)) + self.assertNotIn('"PackedMultiHeadAttention"', str(model)) + self.assertIn('"MultiHeadAttention"', str(model)) self.assertIn("Loop", {n.op_type for n in model.graph.node}) elif attention == "LOOPA24": - self.assertNotIn("PackedMultiHeadAttention", str(model)) - self.assertNotIn("MultiHeadAttention", str(model)) + self.assertNotIn('"PackedMultiHeadAttention"', str(model)) + self.assertNotIn('"MultiHeadAttention"', str(model)) self.assertIn("Loop", {n.op_type for n in model.graph.node}) else: raise AssertionError(f"attention={attention!r} not expected") @@ -257,7 +257,7 @@ def _config_reduction(config, task): else ["CPUExecutionProvider"] ), use_ort=True, - atol=0.02, + atol=0.05, rtol=10, # ep=pt2_file, expected=expected, diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 98aecb63..c6381769 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -693,21 +693,7 @@ def forward(self, query, key, value, seq_lens): ks = key * mask vs = value * mask attn_output = qwen_sdpa_attention_loopmha_versatile( - qs, - ks, - vs, - seq_lens, - 0.11, - 16, - ( - onnx.TensorProto.FLOAT - if query.dtype == torch.float32 - else ( - onnx.TensorProto.FLOAT16 - if query.dtype == torch.float16 - else onnx.TensorProto.BFLOAT16 - ) - ), + qs, ks, vs, seq_lens, 0.11, 16 ) red = attn_output.mean(dim=-1, keepdim=True) return attn_output - red diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index a4a2afb9..cfc65cb2 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -232,7 +232,7 @@ def get_function_proto(self, *args) -> onnx.FunctionProto: ): return self._function_proto_versioned[args[0]] try: - key = self.version_selector(*args) + key = self.version_selector(*args) # type: ignore[misc] except (ValueError, AttributeError) as e: raise AssertionError( f"Unable to select a version, fails to get a key, available=" From 1cf41ae6491b012ca28c7552b1365b5b3558cdde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 19:17:28 +0100 Subject: [PATCH 5/7] require --- .../test_patch_transformers.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 2d85e5f7..6a53d07c 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -7,6 +7,7 @@ from onnx_diagnostic.ext_test_case import ( ExtTestCase, requires_cuda, + requires_onnxruntime, requires_transformers, requires_torch, ignore_warnings, @@ -555,7 +556,7 @@ def _get_seqlen(cls) -> torch.Tensor: 
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") @requires_cuda() - def test_plug_packed_multi_head_attention_qwen25_packed_float16(self): + def test_plug_multi_head_attention_qwen25_packed_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( qwen_sdpa_attention_packed_versatile, ) @@ -583,8 +584,9 @@ def test_plug_packed_multi_head_attention_qwen25_packed_float16(self): self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) self.assertLess(results.diffs[0]["abs"], 0.01) + @requires_onnxruntime("1.24") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") - def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self): + def test_plug_multi_head_attention_qwen25_loopmha_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( qwen_sdpa_attention_loopmha_versatile, ) @@ -617,8 +619,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self): self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) self.assertLess(results.diffs[0]["abs"], 0.01) + @requires_onnxruntime("1.24") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") - def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self): + def test_plug_multi_head_attention_qwen25_loopmha_float32(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( qwen_sdpa_attention_loopmha_versatile, ) @@ -651,8 +654,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self): self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) self.assertLess(results.diffs[0]["abs"], 1e-5) + @requires_onnxruntime("1.24") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") - def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self): + def test_plug_multi_head_attention_qwen25_loopa24_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( qwen_sdpa_attention_loopa24_versatile, ) @@ -678,8 +682,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self): self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005) self.assertLess(results.diffs[0]["abs"], 0.005) + @requires_onnxruntime("1.24") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") - def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self): + def test_plug_multi_head_attention_qwen25_loopa24_float32(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( qwen_sdpa_attention_loopa24_versatile, ) From 41f4ebcca0980e3357ff67ecc27694f8ca986aed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 19:26:41 +0100 Subject: [PATCH 6/7] fix --- _unittests/ut_tasks/try_export.py | 9 ++++++++- onnx_diagnostic/ext_test_case.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index 146d4c4b..0e386ada 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -4,7 +4,12 @@ import onnx import textwrap import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, never_test, ignore_warnings +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + never_test, + ignore_warnings, + has_onnxruntime, +) from 
onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs from onnx_diagnostic.export.api import to_onnx @@ -157,6 +162,8 @@ def _config_reduction(config, task): # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0] for attention in attention_options: + if attention == "LOOPA24" and not has_onnxruntime("1.24"): + continue with self.subTest(attention=attention): print() print(f"-- attention={attention!r}") diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index bf2e09cd..6c19b409 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -610,6 +610,21 @@ def requires_onnxruntime(version: str, msg: str = "") -> Callable: return lambda x: x +def has_onnxruntime(version: str, msg: str = "") -> Callable: + """Skips a unit test if :epkg:`onnxruntime` is not recent enough.""" + import packaging.version as pv + import onnxruntime + + if not hasattr(onnxruntime, "__version__"): + # development version + return True + + if pv.Version(onnxruntime.__version__) < pv.Version(version): + msg = f"onnxruntime version {onnxruntime.__version__} < {version}: {msg}" + return False + return True + + def has_onnxruntime_training(push_back_batch: bool = False): """Tells if onnxruntime_training is installed.""" try: From bbf563cacd1e594e4f473eafba5d451bd9213457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 09:08:21 +0100 Subject: [PATCH 7/7] fix test --- _unittests/ut_tasks/try_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index 0e386ada..d248f9dc 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -222,7 +222,7 @@ def _config_reduction(config, task): self.assertIn('"PackedMultiHeadAttention"', str(model)) elif attention == "BIGMASK": self.assertNotIn('"PackedMultiHeadAttention"', str(model)) - self.assertNotIn("MultiHeadAttention", str(model)) + self.assertIn("MultiHeadAttention", str(model)) self.assertNotIn("Loop", {n.op_type for n in model.graph.node}) elif attention == "LOOPMHA": self.assertNotIn('"PackedMultiHeadAttention"', str(model))
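
Note on the versioned-plug mechanism introduced in PATCH 2/7: ``EagerDirectReplacementWithOnnx``
now accepts either a single ``FunctionProto`` or a dictionary of them, together with a
``version_selector`` that derives the dictionary key from the runtime arguments (here the ONNX
element type of the first tensor). Below is a minimal stand-alone sketch of that dispatch only,
not the library API: ``_make_identity_function`` and ``_torch_to_onnx`` are stand-ins for the
real ``LoopMHAAttention``/``LoopAttention24`` bodies and for the repository's
``torch_dtype_to_onnx_dtype`` helper::

    import onnx
    import onnx.helper as oh
    import torch

    def _make_identity_function(name: str) -> onnx.FunctionProto:
        # Trivial placeholder for the real LoopMHAAttention / LoopAttention24 bodies.
        return oh.make_function(
            "onnx_plug",
            name,
            ["query"],
            ["attn"],
            [oh.make_node("Identity", ["query"], ["attn"])],
            opset_imports=[oh.make_opsetid("", 22)],
        )

    # One FunctionProto per supported element type, as in the patched qwen plug.
    protos = {
        onnx.TensorProto.FLOAT: _make_identity_function("LoopAttentionF32"),
        onnx.TensorProto.FLOAT16: _make_identity_function("LoopAttentionF16"),
    }

    # Stand-in for torch_dtype_to_onnx_dtype, restricted to the two dtypes used here.
    _torch_to_onnx = {
        torch.float32: onnx.TensorProto.FLOAT,
        torch.float16: onnx.TensorProto.FLOAT16,
    }
    version_selector = lambda *args: _torch_to_onnx[args[0].dtype]

    def get_function_proto(*args) -> onnx.FunctionProto:
        # Same selection logic as EagerDirectReplacementWithOnnx.get_function_proto:
        # compute the key from the runtime inputs, then look up the registered proto.
        key = version_selector(*args)
        assert key in protos, f"no FunctionProto registered for element type {key}"
        return protos[key]

    query = torch.rand((1, 16, 1292, 80), dtype=torch.float16)
    print(get_function_proto(query).name)  # -> LoopAttentionF16

This is also why PATCH 3/7 makes ``onnx_diagnostic/export/api.py`` call
``plug.get_function_proto(*flatten_object((args, kwargs), drop_keys=True))``: the registered
function body depends on the input dtype, because the ``SequenceEmpty`` node inside the loop
must be typed to match it (see ``_update_sequence_type`` in the qwen2.5 patch file).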