diff --git a/.github/workflows/models.yml b/.github/workflows/models.yml
index bcb232ca..f252f637 100644
--- a/.github/workflows/models.yml
+++ b/.github/workflows/models.yml
@@ -61,5 +61,5 @@ jobs:
         run: python -m pip freeze
       - name: qwen2.5_vl_instruct
-        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 QWEN25ATTENTION=BIGMASK TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
+        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 05b0ad6e..00ee0232 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.3
 +++++

+* :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
 * :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
 * :pr:`323`: drops torch 2.8 on CI
 * :pr:`322`: support rerunning onnx kernels with torch intermediate results in side-by-side
diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py
index 53420e89..3fb0c587 100644
--- a/_unittests/ut_torch_export_patches/test_patch_transformers.py
+++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py
@@ -521,9 +521,9 @@ def test_qwen2_5_vl_vision_attention_iteration(self):

     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25(self):
+    def test_plug_packed_multi_head_attention_qwen25_packed(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
-            qwen_sdpa_attention_versatile,
+            qwen_sdpa_attention_packed_versatile,
         )

         inputs = (
@@ -563,12 +563,71 @@ def test_plug_packed_multi_head_attention_qwen25(self):
             ).to("cuda"),
         )

-        results = qwen_sdpa_attention_versatile.verify(
+        results = qwen_sdpa_attention_packed_versatile.verify(
+            *inputs, scaling=0.5, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+
+        results = qwen_sdpa_attention_packed_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopmha_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.tensor(
+                [
+                    0,
+                    64,
+                    128,
+                    192,
+                    256,
+                    304,
+                    368,
+                    432,
+                    496,
+                    560,
+                    608,
+                    672,
+                    736,
+                    800,
+                    864,
+                    912,
+                    976,
+                    1040,
+                    1104,
+                    1168,
+                    1216,
+                    1232,
+                    1248,
+                    1264,
+                    1280,
+                    1292,
+                ],
+                dtype=torch.int64,
+            ),
+        )
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
             *inputs,
             scaling=0.5,
             num_heads=16,
             dump_onnx_model=self.get_dump_file(
"test_plug_packed_multi_head_attention_qwen25.onnx" + "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx" ), ) self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) @@ -576,7 +635,7 @@ def test_plug_packed_multi_head_attention_qwen25(self): self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) self.assertLess(results.diffs[0]["abs"], 0.01) - results = qwen_sdpa_attention_versatile.verify( + results = qwen_sdpa_attention_loopmha_versatile.verify( *inputs, scaling=0.11180339887498948, num_heads=16 ) self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py index ce331bd3..edae7f28 100644 --- a/onnx_diagnostic/reference/ort_evaluator.py +++ b/onnx_diagnostic/reference/ort_evaluator.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union import numpy as np from onnx import ( AttributeProto, @@ -388,7 +388,7 @@ def _make_model_proto( onx.opset_import.append(oh.make_opsetid("", onnx_opset_version())) opsets = {d.domain: d.version for d in onx.opset_import} add = {} - for node in nodes: + for node in self.enumerate_nodes(onx.graph.node): if node.domain and node.domain not in opsets and node.domain not in add: add[node.domain] = 1 onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()]) @@ -402,6 +402,15 @@ def _make_model_outputs( ) -> Tuple[List[NodeProto], List[ValueInfoProto]]: return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o] + def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]: + "Enumerates nodes recursively." 
+        for node in nodes:
+            if node.op_type in {"Scan", "If", "Loop"}:
+                for att in node.attribute:
+                    if att.type == AttributeProto.GRAPH:
+                        yield from self.enumerate_nodes(att.g.node)
+            yield node
+
     @classmethod
     def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
         """
diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py
index eb4b493c..5678868a 100644
--- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py
@@ -39,12 +39,13 @@ def LoopMHAAttention(
         query_3d = op.Reshape(query_transposed, to_3d_shape)
         value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
         key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        cu_seqlens = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
         num_patches = op.Size(cu_seqlens) - 1
         seq_axis = op.Constant(value_ints=[1])
         seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
         attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-        for i in range(num_patches):
-            i_1d = op.Reshape(i, [1])
+        for i_patch in range(num_patches):
+            i_1d = op.Reshape(i_patch, [1])
             i_plus_1_1d = i_1d + 1
             start = op.Gather(cu_seqlens, i_1d, axis=0)
             end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0)
@@ -62,6 +63,14 @@ def LoopMHAAttention(
         attn_output_4d = op.Reshape(attn_output, output_shape)
         return attn_output_4d

+    def _add_com_microsoft_opset(function_proto):
+        opsets = {d.domain: d.version for d in function_proto.opset_import}
+        if "com.microsoft" not in opsets:
+            d = function_proto.opset_import.add()
+            d.domain = "com.microsoft"
+            d.version = 1
+        return function_proto
+
     @onnxscript.script(opset=onnx_plugs_op)
     def PackedAttention(
         query,
@@ -143,20 +152,35 @@ def qwen_sdpa_attention(
         return attn_output

     # not ideal
-    qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
+    qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
+        qwen_sdpa_attention,
+        lambda qs, *args, **kwargs: torch.empty(
+            (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
+            dtype=qs.dtype,
+            device=qs.device,
+        ),
+        _add_com_microsoft_opset(PackedAttention.to_function_proto()),
+        n_inputs=4,
+        n_outputs=1,
+        kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+        name="qwen_sdpa_attention_packed",
+    )
+    PLUGS.append(qwen_sdpa_attention_packed_versatile)
+
+    qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
         qwen_sdpa_attention,
         lambda qs, *args, **kwargs: torch.empty(
             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
             dtype=qs.dtype,
             device=qs.device,
         ),
-        PackedAttention.to_function_proto(),
+        _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
         n_inputs=4,
         n_outputs=1,
         kwargs=dict(scaling=0.11180339887498948, num_heads=16),
-        name="qwen_sdpa_attention",
+        name="qwen_sdpa_attention_loopmha",
     )
-    PLUGS.append(qwen_sdpa_attention_versatile)
+    PLUGS.append(qwen_sdpa_attention_loopmha_versatile)


 class patched_Qwen2_5_VLForConditionalGeneration:
     _PATCHES_ = ["prepare_inputs_for_generation"]
@@ -496,8 +520,8 @@ def forward(
             or attention_interface is patched_sdpa_attention_forward
         )
         attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
-        if is_sdpa and attention_strategy == "PACKED":
-            attn_output = qwen_sdpa_attention_versatile(
+        if is_sdpa and attention_strategy in "PACKED":
+            attn_output = qwen_sdpa_attention_packed_versatile(
                 query_states,
                 key_states,
                 value_states,
@@ -530,37 +554,37 @@ def forward(
                 version=1,
             )
         elif is_sdpa and attention_strategy == "LOOPMHA":
+            attn_output = qwen_sdpa_attention_loopmha_versatile(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens,
+                self.scaling,
+                self.num_heads,
+            )

-            def _iteration(start_end, query_states, key_states, value_states):
-                return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
-                    self,
-                    start_end,
-                    query_states,
-                    key_states,
-                    value_states,
-                    scaling=self.scaling,
-                    dropout=0.0 if not self.training else self.attention_dropout,
-                )
-
-            starts = cu_seqlens[:-1]
-            ends = cu_seqlens[1:]
-            # cu_seqlens = [0, 10, 14, 27]
-            # starts: [0, 10, 14]
-            # ends: [10, 14, 17]
-            # starts_ends: [[0, 10], [10, 14], [14, 27]]
-            starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-            attn_outputs = [
-                _iteration(start_end, query_states, key_states, value_states)
-                for start_end in starts_ends
-            ]
-            # attn_outputs = torch._higher_order_ops.while_loop(
-            # attn_outputs = torch.ops.higher_order.while_loop(
-            #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
-            #     _iteration,
-            #     (torch.tensor(0),
-            #     starts_ends, query_states, key_states, value_states), tuple(),
-            # )
-            attn_output = torch.cat(attn_outputs, dim=1)
+            # to rewrite later with a for loop
+            # def _iteration(start_end, query_states, key_states, value_states):
+            #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+            #         self,
+            #         start_end,
+            #         query_states,
+            #         key_states,
+            #         value_states,
+            #         scaling=self.scaling,
+            #         dropout=0.0 if not self.training else self.attention_dropout,
+            #     )
+
+            # starts = cu_seqlens[:-1]
+            # ends = cu_seqlens[1:]
+            # torch._check(starts.shape[0] > 0)
+            # torch._check(ends.shape[0] > 0)
+            # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
+            # attn_outputs = [
+            #     _iteration(start_end, query_states, key_states, value_states)
+            #     for start_end in starts_ends
+            # ]
+            # attn_output = torch.cat(attn_outputs, dim=1)
         elif is_sdpa and attention_strategy == "BIGMASK":
             # make square mask
             indices = torch.arange(
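For reference only, and not part of the patch above: a minimal, self-contained sketch of the recursive traversal pattern that the new `OnnxruntimeEvaluator.enumerate_nodes` relies on, assuming only the official `onnx` package; the graph and variable names below are illustrative, not taken from the repository.

# Standalone sketch (editorial, not from the PR): the same walk as
# enumerate_nodes above, shown on a tiny hand-built model with an If node.
from typing import Iterator, List

import onnx.helper as oh
from onnx import AttributeProto, NodeProto, TensorProto


def enumerate_nodes(nodes: List[NodeProto]) -> Iterator[NodeProto]:
    "Yields every node, descending into If/Loop/Scan graph attributes first."
    for node in nodes:
        if node.op_type in {"Scan", "If", "Loop"}:
            for att in node.attribute:
                if att.type == AttributeProto.GRAPH:
                    yield from enumerate_nodes(att.g.node)
        yield node


# A main graph holding a single If node; each branch contains one Constant node.
then_graph = oh.make_graph(
    [oh.make_node("Constant", [], ["Z"], value_float=1.0)],
    "then_branch",
    [],
    [oh.make_tensor_value_info("Z", TensorProto.FLOAT, [])],
)
else_graph = oh.make_graph(
    [oh.make_node("Constant", [], ["Z"], value_float=-1.0)],
    "else_branch",
    [],
    [oh.make_tensor_value_info("Z", TensorProto.FLOAT, [])],
)
main_graph = oh.make_graph(
    [oh.make_node("If", ["cond"], ["Z"], then_branch=then_graph, else_branch=else_graph)],
    "main",
    [oh.make_tensor_value_info("cond", TensorProto.BOOL, [])],
    [oh.make_tensor_value_info("Z", TensorProto.FLOAT, [])],
)

# Subgraph nodes come out before their parent node, so a pass collecting
# opset imports (as _make_model_proto does) sees every operator domain:
# prints ['Constant', 'Constant', 'If'].
print([n.op_type for n in enumerate_nodes(main_graph.node)])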