
Commit 7e7c2e8

Improves rewriting of LoopMHA, uses ConcatFromSequence (#326)

* first try
* Improves LoopMHA rewriting
* doc
* doc

Parent commit: 035782f

File tree: 3 files changed, +21 −5 lines

* CHANGELOGS.rst
* _unittests/ut_torch_export_patches/test_patch_transformers.py
* onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py
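What the rewrite does, in short: LoopMHAAttention used to seed an empty tensor with Slice and Concat the per-patch MultiHeadAttention output onto it at every loop iteration; it now collects those outputs in an ONNX sequence (SequenceEmpty once, SequenceInsert per iteration) and concatenates a single time after the loop with ConcatFromSequence. A minimal numpy sketch of the two accumulation patterns (illustrative only, not the repository's code):

import numpy as np

# Old pattern: seed an empty tensor along the concat axis, then Concat
# at every iteration (op.Slice + op.Concat in the diff below).
def concat_in_loop(chunks):
    out = chunks[0][:, :0]  # empty along axis=1, like op.Slice(value_3d, [0], [0], seq_axis)
    for c in chunks:
        out = np.concatenate([out, c], axis=1)
    return out

# New pattern: accumulate in a sequence, concatenate once, mirroring
# op.SequenceEmpty / op.SequenceInsert / op.ConcatFromSequence.
def concat_from_sequence(chunks):
    seq = []  # op.SequenceEmpty(dtype=itype)
    for c in chunks:
        seq.append(c)  # op.SequenceInsert(seq_attn, mha_output)
    return np.concatenate(seq, axis=1)  # op.ConcatFromSequence(seq_attn, axis=1)

chunks = [np.random.rand(1, n, 4).astype(np.float32) for n in (2, 3, 5)]
assert np.allclose(concat_in_loop(chunks), concat_from_sequence(chunks))

The new itype parameter exists because SequenceEmpty, unlike the old Slice seed, cannot infer its element type from an existing tensor; it has to be told what the sequence holds.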

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.3
 +++++
 
+* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
 * :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
 * :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
 * :pr:`323`: drops torch 2.8 on CI

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 2 additions & 1 deletion
@@ -626,6 +626,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
             *inputs,
             scaling=0.5,
             num_heads=16,
+            itype=onnx.TensorProto.FLOAT16,
             dump_onnx_model=self.get_dump_file(
                 "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
             ),
@@ -636,7 +637,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
         results = qwen_sdpa_attention_loopmha_versatile.verify(
-            *inputs, scaling=0.11180339887498948, num_heads=16
+            *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
         self.assertEqual(len(results.eager_outputs), len(results.diffs))
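Both verify calls now pass itype=onnx.TensorProto.FLOAT16 so that the SequenceEmpty inside LoopMHAAttention matches the float16 test inputs; its dtype attribute defaults to float32, which would otherwise produce a type mismatch when the float16 MultiHeadAttention outputs are inserted. A tiny standalone model exercising the same three sequence operators (a sketch assuming onnx and onnxruntime are installed; not code from the repository):

import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime as ort

itype = onnx.TensorProto.FLOAT16
model = oh.make_model(
    oh.make_graph(
        [
            # An empty sequence needs an explicit element type.
            oh.make_node("SequenceEmpty", [], ["seq0"], dtype=itype),
            oh.make_node("SequenceInsert", ["seq0", "x1"], ["seq1"]),
            oh.make_node("SequenceInsert", ["seq1", "x2"], ["seq2"]),
            # One concatenation at the end instead of one per step.
            oh.make_node("ConcatFromSequence", ["seq2"], ["y"], axis=1),
        ],
        "concat_from_sequence_demo",
        [
            oh.make_tensor_value_info("x1", itype, [1, None, 4]),
            oh.make_tensor_value_info("x2", itype, [1, None, 4]),
        ],
        [oh.make_tensor_value_info("y", itype, [1, None, 4])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)
onnx.checker.check_model(model)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
x1 = np.random.rand(1, 2, 4).astype(np.float16)
x2 = np.random.rand(1, 3, 4).astype(np.float16)
(y,) = sess.run(None, {"x1": x1, "x2": x2})
assert y.shape == (1, 5, 4)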

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 18 additions & 4 deletions
@@ -32,6 +32,7 @@ def LoopMHAAttention(
     cu_seqlens,
     scaling: float = 0.11180339887498948,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ):
     to_3d_shape = op.Constant(value_ints=[0, 0, -1])
     query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -43,7 +44,8 @@
     num_patches = op.Size(cu_seqlens) - 1
     seq_axis = op.Constant(value_ints=[1])
     seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
-    attn_output = op.Slice(value_3d, [0], [0], seq_axis)
+    # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
+    seq_attn = op.SequenceEmpty(dtype=itype)
     for i_patch in range(num_patches):
         i_1d = op.Reshape(i_patch, [1])
         i_plus_1_1d = i_1d + 1
@@ -59,7 +61,9 @@
             num_heads=num_heads,
             scale=scaling,
         )
-        attn_output = op.Concat(attn_output, mha_output, axis=1)
+        # attn_output = op.Concat(attn_output, mha_output, axis=1)
+        seq_attn = op.SequenceInsert(seq_attn, mha_output)
+    attn_output = op.ConcatFromSequence(seq_attn, axis=1)
     attn_output_4d = op.Reshape(attn_output, output_shape)
     return attn_output_4d
 
@@ -128,6 +132,7 @@ def qwen_sdpa_attention(
     cu_seqlens: torch.Tensor,  # F7su19
     scaling: float = 0,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ) -> torch.Tensor:
     lengths = cu_seqlens[1:] - cu_seqlens[:-1]
     splits = [
@@ -162,7 +167,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(PackedAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_packed",
 )
 PLUGS.append(qwen_sdpa_attention_packed_versatile)
@@ -177,7 +182,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_loopmha",
 )
 PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
@@ -561,6 +566,15 @@ def forward(
             cu_seqlens,
             self.scaling,
             self.num_heads,
+            (
+                onnx.TensorProto.FLOAT
+                if query_states.dtype == torch.float32
+                else (
+                    onnx.TensorProto.FLOAT16
+                    if query_states.dtype == torch.float16
+                    else onnx.TensorProto.BFLOAT16
+                )
+            ),
         )
 
         # to rewrite later with a for loop
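The last hunk in forward() picks the itype at call time from query_states.dtype. The same three-way mapping as a dict-based helper (torch_dtype_to_itype is a hypothetical name, shown for illustration only; the commit keeps the inline conditional):

import onnx
import torch

# Hypothetical helper equivalent to the inline conditional added in forward().
_DTYPE_TO_ITYPE = {
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float16: onnx.TensorProto.FLOAT16,
}

def torch_dtype_to_itype(dtype: torch.dtype) -> int:
    # Any dtype other than float32/float16 falls back to BFLOAT16,
    # matching the final else branch of the original expression.
    return _DTYPE_TO_ITYPE.get(dtype, onnx.TensorProto.BFLOAT16)

assert torch_dtype_to_itype(torch.float16) == onnx.TensorProto.FLOAT16
assert torch_dtype_to_itype(torch.bfloat16) == onnx.TensorProto.BFLOAT16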
