Commit 593a67d

Improves LoopMHA rewriting

Parent: 45f8fec

2 files changed: +16, -7 lines
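In short: this commit threads a new `itype` parameter (an `onnx.TensorProto` element type) through `LoopMHAAttention` and `qwen_sdpa_attention`, so that the `SequenceEmpty` node is created with the element type of the tensors that will be stored in the sequence instead of relying on the operator's float32 default. The plug registrations default `itype` to `onnx.TensorProto.FLOAT`, the patched `forward` derives it from `query_states.dtype`, and the unit test pins it to `FLOAT16`.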

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 2 additions & 1 deletion
@@ -626,6 +626,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
             *inputs,
             scaling=0.5,
             num_heads=16,
+            itype=onnx.TensorProto.FLOAT16,
             dump_onnx_model=self.get_dump_file(
                 "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
             ),
@@ -636,7 +637,7 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
         results = qwen_sdpa_attention_loopmha_versatile.verify(
-            *inputs, scaling=0.11180339887498948, num_heads=16
+            *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
         self.assertEqual(len(results.eager_outputs), len(results.diffs))
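Both `verify` calls now pass `itype=onnx.TensorProto.FLOAT16` explicitly, presumably because the test inputs are float16: the sequence built inside the ONNX function has to carry the same element type as the tensors appended to it for the eager/ONNX comparison to succeed.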

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 14 additions & 6 deletions
@@ -32,6 +32,7 @@ def LoopMHAAttention(
     cu_seqlens,
     scaling: float = 0.11180339887498948,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ):
     to_3d_shape = op.Constant(value_ints=[0, 0, -1])
     query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -44,10 +45,7 @@ def LoopMHAAttention(
     seq_axis = op.Constant(value_ints=[1])
     seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
     # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-    # SequenceEmpty needs dtype to be filled but it should be possible
-    # to leave it empty and just ensure that all tensors stored
-    # in the sequence share the same type.
-    seq_attn = op.SequenceEmpty()  # dtype=onnx.TensorProto.FLOAT16)
+    seq_attn = op.SequenceEmpty(dtype=itype)
     for i_patch in range(num_patches):
         i_1d = op.Reshape(i_patch, [1])
         i_plus_1_1d = i_1d + 1
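For context (not part of the diff): per the ONNX specification, `SequenceEmpty`'s optional `dtype` attribute defaults to float32 when omitted, which is why the old `op.SequenceEmpty()` only matched float32 pipelines unless the commented-out FLOAT16 value was hand-edited in. A minimal node-level sketch of the distinction, using plain `onnx.helper` rather than the onnxscript `op` API used in the patch:

import onnx
from onnx import helper

# Without `dtype`, SequenceEmpty yields a sequence typed float32; appending
# float16 tensors to it later would produce a type mismatch.
node_default = helper.make_node("SequenceEmpty", inputs=[], outputs=["seq"])

# Passing the element type explicitly, as the new `itype` parameter does,
# keeps the sequence type consistent with the tensors stored in it.
node_fp16 = helper.make_node(
    "SequenceEmpty", inputs=[], outputs=["seq"], dtype=onnx.TensorProto.FLOAT16
)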
@@ -134,6 +132,7 @@ def qwen_sdpa_attention(
     cu_seqlens: torch.Tensor,  # F7su19
     scaling: float = 0,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ) -> torch.Tensor:
     lengths = cu_seqlens[1:] - cu_seqlens[:-1]
     splits = [
@@ -168,7 +167,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(PackedAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_packed",
 )
 PLUGS.append(qwen_sdpa_attention_packed_versatile)
@@ -183,7 +182,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_loopmha",
 )
 PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
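Both registrations declare `itype=onnx.TensorProto.FLOAT` in their default kwargs, matching the new parameter's default, so callers that never pass `itype` keep the previous float32 behavior.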
@@ -567,6 +566,15 @@ def forward(
             cu_seqlens,
             self.scaling,
             self.num_heads,
+            (
+                onnx.TensorProto.FLOAT
+                if query_states.dtype == torch.float32
+                else (
+                    onnx.TensorProto.FLOAT16
+                    if query_states.dtype == torch.float16
+                    else onnx.TensorProto.BFLOAT16
+                )
+            ),
         )
 
         # to rewrite later with a for loop
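The nested conditional added to `forward` maps the torch dtype of `query_states` to the matching ONNX element type inline. An equivalent lookup-table form, as a sketch (the helper name `_itype_for` is illustrative, not part of the patch):

import onnx
import torch

# torch dtype -> onnx.TensorProto element type, mirroring the nested
# conditional in forward(); anything other than float32/float16 falls
# back to BFLOAT16, exactly as the inline expression does.
_TORCH_TO_ONNX = {
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float16: onnx.TensorProto.FLOAT16,
}

def _itype_for(t: torch.Tensor) -> int:
    return _TORCH_TO_ONNX.get(t.dtype, onnx.TensorProto.BFLOAT16)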
