2 changes: 2 additions & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
0.8.4
+++++

* :pr:`336`: implements versioned onnx plugs

0.8.3
+++++

39 changes: 26 additions & 13 deletions _unittests/ut_tasks/try_export.py
@@ -4,7 +4,12 @@
import onnx
import textwrap
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, never_test, ignore_warnings
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
never_test,
ignore_warnings,
has_onnxruntime,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.export.api import to_onnx
@@ -17,6 +22,11 @@ class TestTryExportHuggingFaceHubModel(ExtTestCase):
@ignore_warnings(UserWarning)
def test_qwen25_vli_visual(self):
"""
unittest::

UNITTEST_GOING=1 NEVERTEST=1 TESTDTYPE=float16 TESTDEVICE=cpu python \\
_unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual

# task: image-text-to-text
clear&&NEVERTEST=1 python _unittests/ut_tasks/try_export.py -k qwen_2_5

@@ -148,10 +158,12 @@ def _config_reduction(config, task):
elif device == "cuda" and dtype in ("float16", "bfloat16"):
attention_options = ["PACKED", "BIGMASK"]
else:
attention_options = ["LOOPMHA", "BIGMASK"]
attention_options = ["LOOPMHA", "LOOPA24", "BIGMASK"]

# fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
for attention in attention_options:
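# LOOPA24 targets opset 24 (see target_opset below) and only runs when onnxruntime 1.24 is available.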
if attention == "LOOPA24" and not has_onnxruntime("1.24"):
continue
with self.subTest(attention=attention):
print()
print(f"-- attention={attention!r}")
@@ -180,7 +192,7 @@ def _config_reduction(config, task):
exporter=exporter,
verbose=1,
save_ep=None if self.unit_test_going() else (fileep, 2**35),
target_opset=22,
target_opset=24 if attention == "LOOPA24" else 22,
optimize=True,
onnx_plugs=PLUGS,
)
@@ -207,17 +219,18 @@ def _config_reduction(config, task):
print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
model = onnx.load(filename, load_external_data=False)
if attention == "PACKED":
self.assertIn(
"PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
)
self.assertIn('"PackedMultiHeadAttention"', str(model))
elif attention == "BIGMASK":
self.assertNotIn(
"PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
)
self.assertNotIn('"PackedMultiHeadAttention"', str(model))
self.assertIn("MultiHeadAttention", str(model))
self.assertNotIn("Loop", {n.op_type for n in model.graph.node})
elif attention == "LOOPMHA":
self.assertNotIn(
"PackedMultiHeadAttention", {n.op_type for n in model.graph.node}
)
self.assertNotIn('"PackedMultiHeadAttention"', str(model))
self.assertIn('"MultiHeadAttention"', str(model))
self.assertIn("Loop", {n.op_type for n in model.graph.node})
elif attention == "LOOPA24":
self.assertNotIn('"PackedMultiHeadAttention"', str(model))
self.assertNotIn('"MultiHeadAttention"', str(model))
self.assertIn("Loop", {n.op_type for n in model.graph.node})
else:
raise AssertionError(f"attention={attention!r} not expected")
@@ -251,7 +264,7 @@ def _config_reduction(config, task):
else ["CPUExecutionProvider"]
),
use_ort=True,
atol=0.02,
atol=0.05,
rtol=10,
# ep=pt2_file,
expected=expected,
200 changes: 133 additions & 67 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
@@ -7,6 +7,7 @@
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
requires_cuda,
requires_onnxruntime,
requires_transformers,
requires_torch,
ignore_warnings,
@@ -519,9 +520,43 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
)
self.clean_dump()

@classmethod
def _get_seqlen(cls) -> torch.Tensor:
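"""Cumulative sequence lengths (0 through 1292) shared as the seq_lens input by the attention plug tests below."""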
return torch.tensor(
[
0,
64,
128,
192,
256,
304,
368,
432,
496,
560,
608,
672,
736,
800,
864,
912,
976,
1040,
1104,
1168,
1216,
1232,
1248,
1264,
1280,
1292,
],
dtype=torch.int64,
)

@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
@requires_cuda()
def test_plug_packed_multi_head_attention_qwen25_packed(self):
def test_plug_multi_head_attention_qwen25_packed_float16(self):
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_packed_versatile,
)
@@ -530,37 +565,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
torch.tensor(
[
0,
64,
128,
192,
256,
304,
368,
432,
496,
560,
608,
672,
736,
800,
864,
912,
976,
1040,
1104,
1168,
1216,
1232,
1248,
1264,
1280,
1292,
],
dtype=torch.int64,
).to("cuda"),
self._get_seqlen().to("cuda"),
)

results = qwen_sdpa_attention_packed_versatile.verify(
@@ -579,8 +584,9 @@ def _config_reduction(config, task):
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
self.assertLess(results.diffs[0]["abs"], 0.01)

@requires_onnxruntime("1.24")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_loopmha_versatile,
)
@@ -589,46 +595,15 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.tensor(
[
0,
64,
128,
192,
256,
304,
368,
432,
496,
560,
608,
672,
736,
800,
864,
912,
976,
1040,
1104,
1168,
1216,
1232,
1248,
1264,
1280,
1292,
],
dtype=torch.int64,
),
self._get_seqlen(),
)

results = qwen_sdpa_attention_loopmha_versatile.verify(
*inputs,
scaling=0.5,
num_heads=16,
itype=onnx.TensorProto.FLOAT16,
dump_onnx_model=self.get_dump_file(
"test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
"test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
),
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
@@ -637,13 +612,104 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
self.assertLess(results.diffs[0]["abs"], 0.01)

results = qwen_sdpa_attention_loopmha_versatile.verify(
*inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
*inputs, scaling=0.11180339887498948, num_heads=16
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
self.assertLess(results.diffs[0]["abs"], 0.01)

@requires_onnxruntime("1.24")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_loopmha_versatile,
)

inputs = (
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
self._get_seqlen(),
)

results = qwen_sdpa_attention_loopmha_versatile.verify(
*inputs,
scaling=0.5,
num_heads=16,
dump_onnx_model=self.get_dump_file(
"test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
),
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
self.assertLess(results.diffs[0]["abs"], 1e-5)

results = qwen_sdpa_attention_loopmha_versatile.verify(
*inputs, scaling=0.11180339887498948, num_heads=16
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
self.assertLess(results.diffs[0]["abs"], 1e-5)

@requires_onnxruntime("1.24")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_loopa24_versatile,
)

inputs = (
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
self._get_seqlen(),
)

results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2)
self.assertLess(results.diffs[0]["abs"], 1e-2)

results = qwen_sdpa_attention_loopa24_versatile.verify(
*inputs, scaling=0.11180339887498948
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
self.assertLess(results.diffs[0]["abs"], 0.005)

@requires_onnxruntime("1.24")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_loopa24_versatile,
)

inputs = (
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
self._get_seqlen(),
)

results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
self.assertLess(results.diffs[0]["abs"], 1e-5)

results = qwen_sdpa_attention_loopa24_versatile.verify(
*inputs, scaling=0.11180339887498948
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
self.assertLess(results.diffs[0]["abs"], 1e-5)


if __name__ == "__main__":
unittest.main(verbosity=2)
16 changes: 1 addition & 15 deletions _unittests/ut_torch_onnx/test_sbs.py
@@ -693,21 +693,7 @@ def forward(self, query, key, value, seq_lens):
ks = key * mask
vs = value * mask
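# The explicit itype argument is dropped here; presumably the plug now infers the ONNX type from the input dtype.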
attn_output = qwen_sdpa_attention_loopmha_versatile(
qs,
ks,
vs,
seq_lens,
0.11,
16,
(
onnx.TensorProto.FLOAT
if query.dtype == torch.float32
else (
onnx.TensorProto.FLOAT16
if query.dtype == torch.float16
else onnx.TensorProto.BFLOAT16
)
),
qs, ks, vs, seq_lens, 0.11, 16
)
red = attn_output.mean(dim=-1, keepdim=True)
return attn_output - red
8 changes: 7 additions & 1 deletion onnx_diagnostic/export/api.py
@@ -149,6 +149,7 @@ def find_method(self, name: Any):

if exporter in ("dynamo", "onnx-dynamo"):
import os
from ..helpers import flatten_object
import onnxscript.rewriter.ort_fusions as ort_fusions

assert (
@@ -180,7 +181,12 @@ def find_method(self, name: Any):
import onnx_ir as ir
import onnx_ir.passes.common as common_passes

irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
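# With versioned onnx plugs, the FunctionProto is now selected from the flattened example inputs.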
irfunctions = [
ir.from_proto(
plug.get_function_proto(*flatten_object((args, kwargs), drop_keys=True))
)
for plug in onnx_plugs
]
for func in irfunctions:
epo.model.functions[func.identifier()] = func
if inline:
Expand Down