Commit cd01099

add attention 24
1 parent 03a2587 commit cd01099

3 files changed: +210 −97 lines

3 files changed

+210
-97
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 12 additions & 11 deletions
@@ -148,7 +148,7 @@ def _config_reduction(config, task):
         elif device == "cuda" and dtype in ("float16", "bfloat16"):
             attention_options = ["PACKED", "BIGMASK"]
         else:
-            attention_options = ["LOOPMHA", "BIGMASK"]
+            attention_options = ["LOOPMHA", "LOOPA24", "BIGMASK"]
 
         # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
         for attention in attention_options:
@@ -180,7 +180,7 @@ def _config_reduction(config, task):
                 exporter=exporter,
                 verbose=1,
                 save_ep=None if self.unit_test_going() else (fileep, 2**35),
-                target_opset=22,
+                target_opset=24 if attention == "LOOPMHA" else 22,
                 optimize=True,
                 onnx_plugs=PLUGS,
             )
@@ -207,17 +207,18 @@ def _config_reduction(config, task):
             print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
             model = onnx.load(filename, load_external_data=False)
             if attention == "PACKED":
-                self.assertIn(
-                    "PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
-                )
+                self.assertIn("PackedMultiHeadAttention", str(model))
             elif attention == "BIGMASK":
-                self.assertNotIn(
-                    "PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
-                )
+                self.assertNotIn("PackedMultiHeadAttention", str(model))
+                self.assertNotIn("MultiHeadAttention", str(model))
+                self.assertNotIn("Loop", {n.op_type for n in model.graph.node})
             elif attention == "LOOPMHA":
-                self.assertNotIn(
-                    "PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
-                )
+                self.assertNotIn("PackedMultiHeadAttention", str(model))
+                self.assertIn("MultiHeadAttention", str(model))
+                self.assertIn("Loop", {n.op_type for n in model.graph.node})
+            elif attention == "LOOPA24":
+                self.assertNotIn("PackedMultiHeadAttention", str(model))
+                self.assertNotIn("MultiHeadAttention", str(model))
                 self.assertIn("Loop", {n.op_type for n in model.graph.node})
             else:
                 raise AssertionError(f"attention={attention!r} not expected")

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 128 additions & 67 deletions
@@ -519,9 +519,43 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
         )
         self.clean_dump()
 
+    @classmethod
+    def _get_seqlen(cls) -> torch.Tensor:
+        return torch.tensor(
+            [
+                0,
+                64,
+                128,
+                192,
+                256,
+                304,
+                368,
+                432,
+                496,
+                560,
+                608,
+                672,
+                736,
+                800,
+                864,
+                912,
+                976,
+                1040,
+                1104,
+                1168,
+                1216,
+                1232,
+                1248,
+                1264,
+                1280,
+                1292,
+            ],
+            dtype=torch.int64,
+        )
+
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25_packed(self):
+    def test_plug_packed_multi_head_attention_qwen25_packed_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_packed_versatile,
         )
@@ -530,37 +564,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ).to("cuda"),
+            self._get_seqlen().to("cuda"),
         )
 
         results = qwen_sdpa_attention_packed_versatile.verify(
@@ -580,7 +584,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
+    def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )
@@ -589,46 +593,15 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ),
+            self._get_seqlen(),
         )
 
         results = qwen_sdpa_attention_loopmha_versatile.verify(
             *inputs,
             scaling=0.5,
             num_heads=16,
-            itype=onnx.TensorProto.FLOAT16,
             dump_onnx_model=self.get_dump_file(
-                "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
             ),
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
@@ -637,13 +610,101 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
         results = qwen_sdpa_attention_loopmha_versatile.verify(
-            *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
+            *inputs, scaling=0.11180339887498948, num_heads=16
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
         self.assertEqual(len(results.eager_outputs), len(results.diffs))
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopmha_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs,
+            scaling=0.5,
+            num_heads=16,
+            dump_onnx_model=self.get_dump_file(
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
+            ),
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
