
Commit d311955

Implements versioned onnx plugs (#336)
* add attention 24
* implements version plugs
* fix plugs
* fix plugs
* require
* fix
* fix test
1 parent 9bb7013 commit d311955
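This commit lets an ONNX plug carry several function bodies and pick one at export time from the actual inputs (dtype, target opset). A minimal sketch of the idea follows, assuming a hypothetical `VersionedPlug` class and a dtype-based selection rule; only the `get_function_proto(*inputs)` call is confirmed by the diff of `onnx_diagnostic/export/api.py` below.

```python
# Hypothetical sketch of a "versioned" plug: several FunctionProto variants,
# one picked from the runtime inputs. The class name and the selection rule
# are assumptions for illustration; only get_function_proto(*inputs) is
# implied by the change in onnx_diagnostic/export/api.py below.
from typing import Dict

import onnx
import torch


class VersionedPlug:
    def __init__(self, variants: Dict[int, onnx.FunctionProto]):
        # variants maps an ONNX element type (TensorProto.FLOAT, FLOAT16, ...)
        # to the FunctionProto implementing the plug for that dtype.
        self.variants = variants

    def get_function_proto(self, *inputs: torch.Tensor) -> onnx.FunctionProto:
        # Pick the variant matching the dtype of the first input tensor.
        itype = {
            torch.float32: onnx.TensorProto.FLOAT,
            torch.float16: onnx.TensorProto.FLOAT16,
            torch.bfloat16: onnx.TensorProto.BFLOAT16,
        }[inputs[0].dtype]
        return self.variants[itype]
```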

8 files changed: +421 additions, -182 deletions

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -4,6 +4,8 @@ Change Logs
 0.8.4
 +++++

+* :pr:`336`: implements versioned onnx plugs
+
 0.8.3
 +++++
```
_unittests/ut_tasks/try_export.py

Lines changed: 26 additions & 13 deletions
```diff
@@ -4,7 +4,12 @@
 import onnx
 import textwrap
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, never_test, ignore_warnings
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    never_test,
+    ignore_warnings,
+    has_onnxruntime,
+)
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.export.api import to_onnx

@@ -17,6 +22,11 @@ class TestTryExportHuggingFaceHubModel(ExtTestCase):
     @ignore_warnings(UserWarning)
     def test_qwen25_vli_visual(self):
         """
+        unittest::
+
+            UNITTEST_GOING=1 NEVERTEST=1 TESTDTYPE=float16 TESTDEVICE=cpu python \\
+                _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
+
         # task: imagetext2text
         clear&&NEVERTEST=1 python _unittests/ut_tasks/try_export.py -k qwen_2_5

@@ -148,10 +158,12 @@ def _config_reduction(config, task):
         elif device == "cuda" and dtype in ("float16", "bfloat16"):
             attention_options = ["PACKED", "BIGMASK"]
         else:
-            attention_options = ["LOOPMHA", "BIGMASK"]
+            attention_options = ["LOOPMHA", "LOOPA24", "BIGMASK"]

         # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
         for attention in attention_options:
+            if attention == "LOOPA24" and not has_onnxruntime("1.24"):
+                continue
             with self.subTest(attention=attention):
                 print()
                 print(f"-- attention={attention!r}")

@@ -180,7 +192,7 @@ def _config_reduction(config, task):
                     exporter=exporter,
                     verbose=1,
                     save_ep=None if self.unit_test_going() else (fileep, 2**35),
-                    target_opset=22,
+                    target_opset=24 if attention == "LOOPA24" else 22,
                     optimize=True,
                     onnx_plugs=PLUGS,
                 )

@@ -207,17 +219,18 @@ def _config_reduction(config, task):
                 print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
                 model = onnx.load(filename, load_external_data=False)
                 if attention == "PACKED":
-                    self.assertIn(
-                        "PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
-                    )
+                    self.assertIn('"PackedMultiHeadAttention"', str(model))
                 elif attention == "BIGMASK":
-                    self.assertNotIn(
-                        "PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
-                    )
+                    self.assertNotIn('"PackedMultiHeadAttention"', str(model))
+                    self.assertIn("MultiHeadAttention", str(model))
+                    self.assertNotIn("Loop", {n.op_type for n in model.graph.node})
                 elif attention == "LOOPMHA":
-                    self.assertNotIn(
-                        "PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
-                    )
+                    self.assertNotIn('"PackedMultiHeadAttention"', str(model))
+                    self.assertIn('"MultiHeadAttention"', str(model))
+                    self.assertIn("Loop", {n.op_type for n in model.graph.node})
+                elif attention == "LOOPA24":
+                    self.assertNotIn('"PackedMultiHeadAttention"', str(model))
+                    self.assertNotIn('"MultiHeadAttention"', str(model))
                     self.assertIn("Loop", {n.op_type for n in model.graph.node})
                 else:
                     raise AssertionError(f"attention={attention!r} not expected")

@@ -251,7 +264,7 @@ def _config_reduction(config, task):
                     else ["CPUExecutionProvider"]
                 ),
                 use_ort=True,
-                atol=0.02,
+                atol=0.05,
                 rtol=10,
                 # ep=pt2_file,
                 expected=expected,
```
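The diff above adds a LOOPA24 attention option that only runs when the installed onnxruntime is recent enough, and raises the target opset for that variant alone. A condensed sketch of the gating, using the `has_onnxruntime` helper imported in the diff (the export call itself is elided):

```python
# Condensed sketch of the gating added above: skip the opset-24 attention
# variant when onnxruntime < 1.24 and raise the target opset only for it.
from onnx_diagnostic.ext_test_case import has_onnxruntime

for attention in ["LOOPMHA", "LOOPA24", "BIGMASK"]:
    if attention == "LOOPA24" and not has_onnxruntime("1.24"):
        continue  # the LOOPA24 path requires onnxruntime >= 1.24, as gated above
    target_opset = 24 if attention == "LOOPA24" else 22
    print(f"would export with attention={attention!r}, target_opset={target_opset}")
```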

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 133 additions & 67 deletions
```diff
@@ -7,6 +7,7 @@
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     requires_cuda,
+    requires_onnxruntime,
     requires_transformers,
     requires_torch,
     ignore_warnings,

@@ -519,9 +520,43 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
         )
         self.clean_dump()

+    @classmethod
+    def _get_seqlen(cls) -> torch.Tensor:
+        return torch.tensor(
+            [
+                0,
+                64,
+                128,
+                192,
+                256,
+                304,
+                368,
+                432,
+                496,
+                560,
+                608,
+                672,
+                736,
+                800,
+                864,
+                912,
+                976,
+                1040,
+                1104,
+                1168,
+                1216,
+                1232,
+                1248,
+                1264,
+                1280,
+                1292,
+            ],
+            dtype=torch.int64,
+        )
+
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25_packed(self):
+    def test_plug_multi_head_attention_qwen25_packed_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_packed_versatile,
         )

@@ -530,37 +565,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ).to("cuda"),
+            self._get_seqlen().to("cuda"),
         )

         results = qwen_sdpa_attention_packed_versatile.verify(

@@ -579,8 +584,9 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)

+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
+    def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )

@@ -589,46 +595,15 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ),
+            self._get_seqlen(),
         )

         results = qwen_sdpa_attention_loopmha_versatile.verify(
             *inputs,
             scaling=0.5,
             num_heads=16,
-            itype=onnx.TensorProto.FLOAT16,
             dump_onnx_model=self.get_dump_file(
-                "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
             ),
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))

@@ -637,13 +612,104 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)

         results = qwen_sdpa_attention_loopmha_versatile.verify(
-            *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
+            *inputs, scaling=0.11180339887498948, num_heads=16
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
         self.assertEqual(len(results.eager_outputs), len(results.diffs))
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)

+    @requires_onnxruntime("1.24")
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopmha_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs,
+            scaling=0.5,
+            num_heads=16,
+            dump_onnx_model=self.get_dump_file(
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
+            ),
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+    @requires_onnxruntime("1.24")
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2)
+        self.assertLess(results.diffs[0]["abs"], 1e-2)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
+        self.assertLess(results.diffs[0]["abs"], 0.005)
+
+    @requires_onnxruntime("1.24")
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
```
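All of the new tests follow the same pattern: build random Q/K/V plus cumulative sequence lengths, call `verify()` on the plug, and compare eager against onnxruntime outputs. A condensed sketch with shortened, illustrative `seq_lens` values; the attribute names match the assertions above:

```python
# Condensed sketch of the verification pattern used above. verify() runs the
# eager implementation and the exported ONNX function side by side and reports
# eager_outputs, onnx_outputs and per-output diffs.
import torch
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
    qwen_sdpa_attention_loopa24_versatile,
)

q = torch.rand((1, 16, 1292, 80), dtype=torch.float32)
k = torch.rand((1, 16, 1292, 80), dtype=torch.float32)
v = torch.rand((1, 16, 1292, 80), dtype=torch.float32)
seq_lens = torch.tensor([0, 64, 128, 1292], dtype=torch.int64)  # shortened, illustrative
results = qwen_sdpa_attention_loopa24_versatile.verify(q, k, v, seq_lens, scaling=0.5)
assert results.diffs[0]["abs"] < 1e-5  # eager vs. onnxruntime discrepancy
```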

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 1 addition & 15 deletions
```diff
@@ -693,21 +693,7 @@ def forward(self, query, key, value, seq_lens):
         ks = key * mask
         vs = value * mask
         attn_output = qwen_sdpa_attention_loopmha_versatile(
-            qs,
-            ks,
-            vs,
-            seq_lens,
-            0.11,
-            16,
-            (
-                onnx.TensorProto.FLOAT
-                if query.dtype == torch.float32
-                else (
-                    onnx.TensorProto.FLOAT16
-                    if query.dtype == torch.float16
-                    else onnx.TensorProto.BFLOAT16
-                )
-            ),
+            qs, ks, vs, seq_lens, 0.11, 16
         )
         red = attn_output.mean(dim=-1, keepdim=True)
         return attn_output - red
```
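The removed block passed the ONNX element type explicitly; with versioned plugs the call site no longer needs it, presumably because the plug now derives it from the input dtype. For reference, the mapping the removed expression encoded:

```python
# The dtype -> ONNX element-type mapping that the removed argument spelled out
# inline; after this commit the plug is expected to derive it from query.dtype.
import onnx
import torch


def _itype(dtype: torch.dtype) -> int:
    return (
        onnx.TensorProto.FLOAT
        if dtype == torch.float32
        else (
            onnx.TensorProto.FLOAT16
            if dtype == torch.float16
            else onnx.TensorProto.BFLOAT16
        )
    )
```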

onnx_diagnostic/export/api.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -149,6 +149,7 @@ def find_method(self, name: Any):

     if exporter in ("dynamo", "onnx-dynamo"):
         import os
+        from ..helpers import flatten_object
         import onnxscript.rewriter.ort_fusions as ort_fusions

         assert (

@@ -180,7 +181,12 @@ def find_method(self, name: Any):
         import onnx_ir as ir
         import onnx_ir.passes.common as common_passes

-        irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
+        irfunctions = [
+            ir.from_proto(
+                plug.get_function_proto(*flatten_object((args, kwargs), drop_keys=True))
+            )
+            for plug in onnx_plugs
+        ]
         for func in irfunctions:
             epo.model.functions[func.identifier()] = func
         if inline:
```
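A condensed restatement of the change above, wrapped as a standalone helper: flatten the example `(args, kwargs)` and let each plug return the `FunctionProto` variant matching those inputs before registering it on the exported IR model. The helper name and parameters are illustrative; `epo_model`, `onnx_plugs`, `args`, and `kwargs` stand in for the enclosing `to_onnx` locals shown in the diff.

```python
# Illustrative restatement of the loop above: each plug now receives the
# flattened example inputs and hands back the matching FunctionProto variant
# (e.g. a float16 vs. float32 body), which is registered on the IR model.
import onnx_ir as ir
from onnx_diagnostic.helpers import flatten_object


def register_plug_functions(epo_model, onnx_plugs, args, kwargs):
    flat_inputs = flatten_object((args, kwargs), drop_keys=True)
    for plug in onnx_plugs:
        func = ir.from_proto(plug.get_function_proto(*flat_inputs))
        epo_model.functions[func.identifier()] = func
```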
