
Commit 035782f

Fixes onnx plug for LoopMHA (Qwen2.5) (#325)
* fix plug for loopMHA
* mypy
* fix
* doc
* doc
* remove cuda
1 parent 5842d50 commit 035782f

File tree

5 files changed: +139 / -46 lines

* .github/workflows/models.yml
* CHANGELOGS.rst
* _unittests/ut_torch_export_patches/test_patch_transformers.py
* onnx_diagnostic/reference/ort_evaluator.py
* onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

.github/workflows/models.yml

Lines changed: 1 addition & 1 deletion
@@ -61,5 +61,5 @@ jobs:
         run: python -m pip freeze
 
       - name: qwen2.5_vl_instruct
-        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 QWEN25ATTENTION=BIGMASK TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
+        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
 
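The workflow change drops the QWEN25ATTENTION=BIGMASK override, so the CI test now runs with the default attention strategy. As a hedged sketch only (the strategy names come from the patched forward shown further below; the real STRATEGY_FOR_ATTENTION implementation and its default are not part of this diff), such an environment switch typically looks like:

```python
import os

# Hypothetical illustration, not the library code: read the QWEN25ATTENTION
# environment variable and map it onto one of the strategies used by the patch.
def strategy_for_attention(default: str = "LOOPMHA") -> str:
    value = os.environ.get("QWEN25ATTENTION", default).upper()
    if value not in {"PACKED", "LOOPMHA", "BIGMASK"}:
        raise ValueError(f"unexpected QWEN25ATTENTION value {value!r}")
    return value
```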

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.3
 +++++
 
+* :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
 * :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
 * :pr:`323`: drops torch 2.8 on CI
 * :pr:`322`: support rerunning onnx kernels with torch intermediate results in side-by-side

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 64 additions & 5 deletions
@@ -521,9 +521,9 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
 
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25(self):
+    def test_plug_packed_multi_head_attention_qwen25_packed(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
-            qwen_sdpa_attention_versatile,
+            qwen_sdpa_attention_packed_versatile,
         )
 
         inputs = (
@@ -563,20 +563,79 @@ def test_plug_packed_multi_head_attention_qwen25(self):
             ).to("cuda"),
         )
 
-        results = qwen_sdpa_attention_versatile.verify(
+        results = qwen_sdpa_attention_packed_versatile.verify(
+            *inputs, scaling=0.5, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+
+        results = qwen_sdpa_attention_packed_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopmha_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.tensor(
+                [0, 64, 128, 192, 256, 304, 368, 432, 496, 560, 608, 672,
+                 736, 800, 864, 912, 976, 1040, 1104, 1168, 1216, 1232,
+                 1248, 1264, 1280, 1292],
+                dtype=torch.int64,
+            ),
+        )
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
             *inputs,
             scaling=0.5,
             num_heads=16,
             dump_onnx_model=self.get_dump_file(
-                "test_plug_packed_multi_head_attention_qwen25.onnx"
+                "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
             ),
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
         self.assertEqual(len(results.eager_outputs), len(results.diffs))
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
-        results = qwen_sdpa_attention_versatile.verify(
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
             *inputs, scaling=0.11180339887498948, num_heads=16
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
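The verify calls use an arbitrary scaling of 0.5 and 0.11180339887498948, which is 1/sqrt(80), the usual 1/sqrt(head_dim) scaling for the head dimension 80 of these inputs. The cu_seqlens tensor holds cumulative patch boundaries, so it starts at 0, ends at the sequence length (1292 here) and increases monotonically. A quick standalone sanity check (not part of the test file):

```python
import math
import torch

head_dim = 80
# 0.11180339887498948 is 1/sqrt(head_dim) for head_dim=80
assert math.isclose(1 / math.sqrt(head_dim), 0.11180339887498948)

cu_seqlens = torch.tensor(
    [0, 64, 128, 192, 256, 304, 368, 432, 496, 560, 608, 672, 736, 800, 864,
     912, 976, 1040, 1104, 1168, 1216, 1232, 1248, 1264, 1280, 1292]
)
seq_len = 1292
assert cu_seqlens[0] == 0 and cu_seqlens[-1] == seq_len
assert bool((cu_seqlens[1:] > cu_seqlens[:-1]).all())  # strictly increasing boundaries
print(len(cu_seqlens) - 1, "patches")  # 25 patches
```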

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 11 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 import numpy as np
 from onnx import (
     AttributeProto,
@@ -388,7 +388,7 @@ def _make_model_proto(
             onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
         opsets = {d.domain: d.version for d in onx.opset_import}
         add = {}
-        for node in nodes:
+        for node in self.enumerate_nodes(onx.graph.node):
             if node.domain and node.domain not in opsets and node.domain not in add:
                 add[node.domain] = 1
         onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])
@@ -402,6 +402,15 @@ def _make_model_outputs(
     ) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
         return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
 
+    def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
+        "Enumerates nodes recursively."
+        for node in nodes:
+            if node.op_type in {"Scan", "If", "Loop"}:
+                for att in node.attribute:
+                    if att.type == AttributeProto.GRAPH:
+                        yield from self.enumerate_nodes(att.g.node)
+            yield node
+
     @classmethod
     def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
         """

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 62 additions & 38 deletions
@@ -39,12 +39,13 @@ def LoopMHAAttention(
     query_3d = op.Reshape(query_transposed, to_3d_shape)
     value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
     key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+    cu_seqlens = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
     num_patches = op.Size(cu_seqlens) - 1
     seq_axis = op.Constant(value_ints=[1])
     seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
     attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-    for i in range(num_patches):
-        i_1d = op.Reshape(i, [1])
+    for i_patch in range(num_patches):
+        i_1d = op.Reshape(i_patch, [1])
         i_plus_1_1d = i_1d + 1
         start = op.Gather(cu_seqlens, i_1d, axis=0)
         end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0)
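For reference, the loop expressed in ONNX above computes attention per patch. A plain eager sketch of that computation (assumed input shapes (batch, num_heads, seq, head_dim) as in the unit test; the trailing transpose matches the (batch, seq, heads, head_dim) placeholder declared by the replacement registered below, not every detail of the real plug):

```python
import torch
import torch.nn.functional as F

def loop_mha_reference(q, k, v, cu_seqlens, scaling):
    """Runs attention on each [start, end) patch independently, then concatenates."""
    outputs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        outputs.append(
            F.scaled_dot_product_attention(
                q[:, :, start:end], k[:, :, start:end], v[:, :, start:end], scale=scaling
            )
        )
    out = torch.cat(outputs, dim=2)  # (batch, num_heads, seq, head_dim)
    return out.transpose(1, 2)       # (batch, seq, num_heads, head_dim)
```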
@@ -62,6 +63,14 @@ def LoopMHAAttention(
     attn_output_4d = op.Reshape(attn_output, output_shape)
     return attn_output_4d
 
+def _add_com_microsoft_opset(function_proto):
+    opsets = {d.domain: d.version for d in function_proto.opset_import}
+    if "com.microsoft" not in opsets:
+        d = function_proto.opset_import.add()
+        d.domain = "com.microsoft"
+        d.version = 1
+    return function_proto
+
 @onnxscript.script(opset=onnx_plugs_op)
 def PackedAttention(
     query,
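The helper only touches opset_import: a FunctionProto that uses com.microsoft operators also has to declare that domain, otherwise onnxruntime cannot resolve the contrib ops. A standalone check of the effect (the helper is re-stated here, and the one-node function body is a made-up wrapper, not the library's):

```python
import onnx
import onnx.helper as oh

def add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
    """Adds the com.microsoft opset to a FunctionProto if it is missing."""
    opsets = {d.domain: d.version for d in function_proto.opset_import}
    if "com.microsoft" not in opsets:
        d = function_proto.opset_import.add()
        d.domain = "com.microsoft"
        d.version = 1
    return function_proto

fn = oh.make_function(
    "custom_domain",
    "AttentionWrapper",
    ["q", "k", "v"],
    ["y"],
    [oh.make_node("MultiHeadAttention", ["q", "k", "v"], ["y"], domain="com.microsoft")],
    [oh.make_opsetid("", 18)],  # the contrib domain is initially missing
)
add_com_microsoft_opset(fn)
print({d.domain: d.version for d in fn.opset_import})  # {'': 18, 'com.microsoft': 1}
```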
@@ -143,20 +152,35 @@ def qwen_sdpa_attention(
     return attn_output
 
 # not ideal
-qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
+qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
+    qwen_sdpa_attention,
+    lambda qs, *args, **kwargs: torch.empty(
+        (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
+        dtype=qs.dtype,
+        device=qs.device,
+    ),
+    _add_com_microsoft_opset(PackedAttention.to_function_proto()),
+    n_inputs=4,
+    n_outputs=1,
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    name="qwen_sdpa_attention_packed",
+)
+PLUGS.append(qwen_sdpa_attention_packed_versatile)
+
+qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
     qwen_sdpa_attention,
     lambda qs, *args, **kwargs: torch.empty(
         (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
         dtype=qs.dtype,
         device=qs.device,
     ),
-    PackedAttention.to_function_proto(),
+    _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
     kwargs=dict(scaling=0.11180339887498948, num_heads=16),
-    name="qwen_sdpa_attention",
+    name="qwen_sdpa_attention_loopmha",
 )
-PLUGS.append(qwen_sdpa_attention_versatile)
+PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
 
 class patched_Qwen2_5_VLForConditionalGeneration:
     _PATCHES_ = ["prepare_inputs_for_generation"]
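Both plugs share the same eager implementation and the same output-metadata lambda; that lambda turns a (batch, heads, seq, head_dim) query into a (batch, seq, heads, head_dim) placeholder, presumably the layout returned by the eager function and the ONNX plugs. A quick check of the lambda on the test shapes:

```python
import torch

# The same lambda as registered above, applied to the shapes used in the unit test.
meta_fn = lambda qs, *args, **kwargs: torch.empty(
    (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]), dtype=qs.dtype, device=qs.device
)
qs = torch.empty((1, 16, 1292, 80), dtype=torch.float16)
print(meta_fn(qs).shape)  # torch.Size([1, 1292, 16, 80])
```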
@@ -496,8 +520,8 @@ def forward(
             or attention_interface is patched_sdpa_attention_forward
         )
         attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
-        if is_sdpa and attention_strategy == "PACKED":
-            attn_output = qwen_sdpa_attention_versatile(
+        if is_sdpa and attention_strategy in "PACKED":
+            attn_output = qwen_sdpa_attention_packed_versatile(
                 query_states,
                 key_states,
                 value_states,
@@ -530,37 +554,37 @@ def forward(
                 version=1,
             )
         elif is_sdpa and attention_strategy == "LOOPMHA":
+            attn_output = qwen_sdpa_attention_loopmha_versatile(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens,
+                self.scaling,
+                self.num_heads,
+            )
 
-            def _iteration(start_end, query_states, key_states, value_states):
-                return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
-                    self,
-                    start_end,
-                    query_states,
-                    key_states,
-                    value_states,
-                    scaling=self.scaling,
-                    dropout=0.0 if not self.training else self.attention_dropout,
-                )
-
-            starts = cu_seqlens[:-1]
-            ends = cu_seqlens[1:]
-            # cu_seqlens = [0, 10, 14, 27]
-            # starts: [0, 10, 14]
-            # ends: [10, 14, 17]
-            # starts_ends: [[0, 10], [10, 14], [14, 27]]
-            starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-            attn_outputs = [
-                _iteration(start_end, query_states, key_states, value_states)
-                for start_end in starts_ends
-            ]
-            # attn_outputs = torch._higher_order_ops.while_loop(
-            # attn_outputs = torch.ops.higher_order.while_loop(
-            #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
-            #     _iteration,
-            #     (torch.tensor(0),
-            #      starts_ends, query_states, key_states, value_states), tuple(),
-            # )
-            attn_output = torch.cat(attn_outputs, dim=1)
+            # to rewrite later with a for loop
+            # def _iteration(start_end, query_states, key_states, value_states):
+            #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+            #         self,
+            #         start_end,
+            #         query_states,
+            #         key_states,
+            #         value_states,
+            #         scaling=self.scaling,
+            #         dropout=0.0 if not self.training else self.attention_dropout,
+            #     )
+
+            # starts = cu_seqlens[:-1]
+            # ends = cu_seqlens[1:]
+            # torch._check(starts.shape[0] > 0)
+            # torch._check(ends.shape[0] > 0)
+            # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
+            # attn_outputs = [
+            #     _iteration(start_end, query_states, key_states, value_states)
+            #     for start_end in starts_ends
+            # ]
+            # attn_output = torch.cat(attn_outputs, dim=1)
         elif is_sdpa and attention_strategy == "BIGMASK":
             # make square mask
             indices = torch.arange(
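The BIGMASK branch, kept from before this commit, builds one square mask from cu_seqlens so a single attention call cannot look across patch boundaries. A hedged sketch of that idea, not the patched code itself:

```python
import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Boolean (seq_len, seq_len) mask: True where two positions share a patch."""
    indices = torch.arange(seq_len)
    # patch id of every position, derived from the cumulative boundaries
    patch_ids = torch.bucketize(indices, cu_seqlens[1:-1], right=True)
    return patch_ids.unsqueeze(0) == patch_ids.unsqueeze(1)

# cu_seqlens=[0, 3, 5]: positions 0-2 attend among themselves, positions 3-4 among themselves.
mask = block_diagonal_mask(torch.tensor([0, 3, 5]), 5)
print(mask.int())
```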
