Skip to content

Commit 60ee6ce

Browse files
authored
Refactors qwen 2.5 vl patches (#311)
* refactor qwen attention * doc * onnx_plug * fix test * mypy * fix inh
1 parent 89b137c commit 60ee6ce

File tree

9 files changed

+631
-67
lines changed

9 files changed

+631
-67
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ Change Logs
44
0.8.3
55
+++++
66

7-
* :pr:`310`: split patches into multiple files
7+
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
8+
* :pr:`310`: splits patches into multiple files
89
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
910
* :pr:`304`, :pr:`306`: improves side-by-side comparison, creates command line sbs
1011

_doc/api/export/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ onnx_diagnostic.export
88
api
99
control_flow
1010
dynamic_shapes
11+
onnx_plug
1112
shape_helper
1213
validate
1314

_doc/api/export/onnx_plug.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.export.onnx_plug
3+
================================
4+
5+
.. automodule:: onnx_diagnostic.export.onnx_plug
6+
:members:
7+
:no-undoc-members:

_unittests/ut_export/test_control_flow.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ def body(i, x):
6666
ep = torch.export.export(
6767
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
6868
)
69-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
69+
self.assertIn(
70+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
71+
str(ep),
72+
)
7073

7174
onx = to_onnx(
7275
model,
@@ -97,7 +100,10 @@ def body(i, x):
97100
ep = torch.export.export(
98101
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
99102
)
100-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
103+
self.assertIn(
104+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
105+
str(ep),
106+
)
101107

102108
onx = to_onnx(
103109
model,
@@ -132,7 +138,10 @@ def body(i, x):
132138
ep = torch.export.export(
133139
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
134140
)
135-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
141+
self.assertIn(
142+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
143+
str(ep),
144+
)
136145

137146
onx = to_onnx(
138147
model,
@@ -164,7 +173,10 @@ def body(i, x):
164173
ep = torch.export.export(
165174
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
166175
)
167-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
176+
self.assertIn(
177+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
178+
str(ep),
179+
)
168180

169181
onx = to_onnx(
170182
model,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import unittest
2+
import onnx.helper as oh
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch
5+
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
6+
from onnx_diagnostic.export.api import to_onnx
7+
8+
9+
class TestOnnxPlus(ExtTestCase):
10+
def test_onnx_plug_verify(self):
11+
def _test_customadd(x, y):
12+
return x + y
13+
14+
def _test_customadd_shape(x, y):
15+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
16+
17+
def make_function_proto():
18+
return oh.make_function(
19+
"onnx_plug",
20+
"_test_customadd",
21+
["x", "y"],
22+
["z"],
23+
[oh.make_node("Add", ["x", "y"], ["z"])],
24+
opset_imports=[oh.make_opsetid("", 22)],
25+
)
26+
27+
rep = EagerDirectReplacementWithOnnx(
28+
_test_customadd, _test_customadd_shape, make_function_proto(), 2, 1
29+
)
30+
31+
x = torch.randn((3, 4), dtype=torch.float32)
32+
y = torch.randn((3, 1), dtype=torch.float32)
33+
self.assertEqualArray(_test_customadd(x, y), x + y)
34+
res = rep.verify(x, y)
35+
self.assertEqualAny(res.eager_outputs, (x + y,))
36+
self.assertEqual(len(res.diffs), 1)
37+
self.assertEqual(res.diffs[0]["abs"], 0)
38+
39+
def test_onnx_plug_export(self):
40+
def _test_customsub(x, y):
41+
return x - y
42+
43+
def _test_customsub_shape(x, y):
44+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
45+
46+
def make_function_proto():
47+
return oh.make_function(
48+
"onnx_plug",
49+
"_test_customsub",
50+
["x", "y"],
51+
["z"],
52+
[oh.make_node("Sub", ["x", "y"], ["z"])],
53+
opset_imports=[oh.make_opsetid("", 22)],
54+
)
55+
56+
class Model(torch.nn.Module):
57+
def forward(self, x):
58+
y = x.sum(axis=1, keepdim=True)
59+
d = torch.ops.onnx_plug._test_customsub(x, y)
60+
return torch.abs(d)
61+
62+
replacements = [
63+
EagerDirectReplacementWithOnnx(
64+
_test_customsub, _test_customsub_shape, make_function_proto(), 2, 1
65+
)
66+
]
67+
68+
x = torch.randn((3, 4), dtype=torch.float32)
69+
model = Model()
70+
expected = model(x)
71+
ds = ({0: "d1", 1: "d2"},)
72+
ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds))
73+
self.assertIn("torch.ops.onnx_plug._test_customsub.default", str(ep))
74+
got = ep.module()(x)
75+
self.assertEqualArray(expected, got)
76+
77+
with self.subTest(exporter="custom"):
78+
onx = to_onnx(
79+
model,
80+
(x,),
81+
dynamic_shapes=ds,
82+
exporter="custom",
83+
onnx_plugs=replacements,
84+
target_opset=22,
85+
)
86+
self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,))
87+
88+
if not has_torch("2.9"):
89+
raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
90+
with self.subTest(exporter="onnx-dynamo"):
91+
onx = to_onnx(
92+
model,
93+
(x,),
94+
dynamic_shapes=ds,
95+
exporter="onnx-dynamo",
96+
onnx_plugs=replacements,
97+
target_opset=22,
98+
)
99+
self.assert_onnx_disc(
100+
"test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,)
101+
)
102+
103+
104+
if __name__ == "__main__":
105+
unittest.main(verbosity=2)

onnx_diagnostic/export/api.py

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
22
import torch
3+
from .onnx_plug import EagerDirectReplacementWithOnnx
34

45

56
def to_onnx(
@@ -18,6 +19,8 @@ def to_onnx(
1819
save_ep: Optional[str] = None,
1920
optimize: bool = True,
2021
use_control_flow_dispatcher: bool = False,
22+
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
23+
inline: bool = True,
2124
) -> Any:
2225
"""
2326
Common API for exporters. By default, the models are optimized to use the
@@ -41,6 +44,8 @@ def to_onnx(
4144
:param optimize: optimizes the model
4245
:param use_control_flow_dispatcher: use the dispatcher created to supported
4346
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
47+
:param onnx_plugs: the code was modified to replace some parts with onnx translation
48+
:param inline: inline local functions
4449
:return: the output of the selected exporter, usually a structure including
4550
an onnx model
4651
@@ -55,24 +60,73 @@ def to_onnx(
5560
exporter=exporter,
5661
filename=filename,
5762
)
63+
64+
Some examples using control flows are available in
65+
:func:`onnx_diagnostic.export.control_flow.loop_for` or
66+
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
5867
"""
68+
if exporter_kwargs and "inline" in exporter_kwargs:
69+
assert (
70+
inline == exporter_kwargs["inline"]
71+
), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
72+
exporter_kwargs.pop("inline")
5973
if exporter == "custom":
6074
from experimental_experiment.torch_interpreter import (
6175
to_onnx as _to_onnx,
6276
ExportOptions,
6377
)
6478
from experimental_experiment.xbuilder import OptimizationOptions
6579

66-
if use_control_flow_dispatcher:
67-
from .control_flow import create_global_dispatcher
68-
69-
dispatcher = create_global_dispatcher()
70-
7180
options = None
7281
if exporter_kwargs is not None:
7382
options = exporter_kwargs.pop("options", None)
7483
if options is None:
7584
options = OptimizationOptions(patterns="default+onnxruntime")
85+
if onnx_plugs or use_control_flow_dispatcher:
86+
from experimental_experiment.torch_interpreter import Dispatcher
87+
88+
if use_control_flow_dispatcher:
89+
from .control_flow import create_global_dispatcher
90+
91+
control_flow_dispatcher = create_global_dispatcher()
92+
else:
93+
control_flow_dispatcher = None
94+
95+
class MainDispatcher(Dispatcher):
96+
def __init__(self, previous_dispatcher=None):
97+
super().__init__({})
98+
self.previous_dispatcher = previous_dispatcher
99+
100+
@property
101+
def supported(self):
102+
if self.previous_dispatcher:
103+
return (
104+
set(self.registered_functions) | self.previous_dispatcher.supported
105+
)
106+
return set(self.registered_functions)
107+
108+
def find_function(self, name: Any):
109+
if self.previous_dispatcher:
110+
find = self.previous_dispatcher.find_function(name)
111+
if find:
112+
return find
113+
return Dispatcher.find_function(self, name)
114+
115+
def find_method(self, name: Any):
116+
if self.previous_dispatcher:
117+
find = self.previous_dispatcher.find_method(name)
118+
if find:
119+
return find
120+
return Dispatcher.find_method(self, name)
121+
122+
main_dispatcher = MainDispatcher(control_flow_dispatcher)
123+
if onnx_plugs:
124+
for plug in onnx_plugs:
125+
main_dispatcher.registered_functions[plug.target_name] = (
126+
plug.custom_converter()
127+
)
128+
else:
129+
main_dispatcher = None
76130

77131
return _to_onnx(
78132
mod,
@@ -88,8 +142,9 @@ def to_onnx(
88142
output_dynamic_shapes=output_dynamic_shapes,
89143
export_options=ExportOptions(save_ep=save_ep),
90144
options=options,
145+
inline=inline,
146+
dispatcher=main_dispatcher,
91147
**(exporter_kwargs or {}),
92-
dispatcher=dispatcher if use_control_flow_dispatcher else None,
93148
)
94149

95150
if exporter in ("dynamo", "onnx-dynamo"):
@@ -99,6 +154,10 @@ def to_onnx(
99154
assert (
100155
not output_dynamic_shapes
101156
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
157+
custom_translation_table = {}
158+
if onnx_plugs:
159+
for plug in onnx_plugs:
160+
custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
102161
epo = torch.onnx.export(
103162
mod,
104163
args=args or tuple(),
@@ -111,9 +170,24 @@ def to_onnx(
111170
verbose=verbose,
112171
dump_exported_program=bool(save_ep),
113172
artifacts_dir=os.path.dirname(filename) if filename else ".",
173+
custom_translation_table=custom_translation_table,
114174
**(exporter_kwargs or {}),
115175
)
116-
if optimize:
176+
if not inline and optimize:
177+
ort_fusions.optimize_for_ort(epo.model)
178+
179+
if onnx_plugs:
180+
import onnx_ir as ir
181+
import onnx_ir.passes.common as common_passes
182+
183+
irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
184+
for func in irfunctions:
185+
epo.model.functions[func.identifier()] = func
186+
if inline:
187+
common_passes.InlinePass()(epo.model)
188+
common_passes.RemoveUnusedOpsetsPass()(epo.model)
189+
190+
if inline and optimize:
117191
ort_fusions.optimize_for_ort(epo.model)
118192
if filename:
119193
epo.save(filename, external_data=True)

onnx_diagnostic/export/control_flow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def register(self, aten_name: str, converter: Callable):
3636

3737
@contextlib.contextmanager
3838
def enable_code_export_control_flow():
39-
"""Enables the code means to be exported."""
39+
"""Enables the code meant to be exported."""
4040
global _TEST_EXPORT
4141
old = _TEST_EXPORT
4242
_TEST_EXPORT = True
@@ -134,7 +134,8 @@ def make_custom_loop_for(
134134
assert body_outputs is not None, "body_outputs cannot be None"
135135
srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
136136
sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
137-
name = f"loop_for_{body_fn.__name__}_{id(body_fn)}_{srank}_{sred}"
137+
full_name = body_fn.__qualname__.replace("<locals>", "L").replace(".", "_")
138+
name = f"loop_for_{full_name}_{srank}_{sred}"
138139
if name in _REGISTERED_SCHEMA:
139140
return name, _REGISTERED_SCHEMA[name][0]
140141
sig = inspect.signature(body_fn)

0 commit comments

Comments
 (0)