Skip to content

Commit e786ca0

Browse files
committed
onnx_plug
1 parent 4579cf9 commit e786ca0

File tree

6 files changed

+399
-7
lines changed

6 files changed

+399
-7
lines changed

_doc/api/export/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ onnx_diagnostic.export
88
api
99
control_flow
1010
dynamic_shapes
11+
onnx_plug
1112
shape_helper
1213
validate
1314

_doc/api/export/onnx_plug.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.export.onnx_plug
3+
================================
4+
5+
.. automodule:: onnx_diagnostic.export.onnx_plug
6+
:members:
7+
:no-undoc-members:
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import unittest
2+
import onnx.helper as oh
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
6+
from onnx_diagnostic.export.api import to_onnx
7+
8+
9+
class TestOnnxPlus(ExtTestCase):
10+
def test_onnx_plug_verify(self):
11+
def _test_customadd(x, y):
12+
return x + y
13+
14+
def _test_customadd_shape(x, y):
15+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
16+
17+
def make_function_proto():
18+
return oh.make_function(
19+
"onnx_plug",
20+
"_test_customadd",
21+
["x", "y"],
22+
["z"],
23+
[oh.make_node("Add", ["x", "y"], ["z"])],
24+
opset_imports=[oh.make_opsetid("", 22)],
25+
)
26+
27+
rep = EagerDirectReplacementWithOnnx(
28+
_test_customadd, _test_customadd_shape, make_function_proto(), 2, 1
29+
)
30+
31+
x = torch.randn((3, 4), dtype=torch.float32)
32+
y = torch.randn((3, 1), dtype=torch.float32)
33+
self.assertEqualArray(_test_customadd(x, y), x + y)
34+
res = rep.verify(x, y)
35+
self.assertEqualAny(res.eager_outputs, (x + y,))
36+
self.assertEqual(len(res.diffs), 1)
37+
self.assertEqual(res.diffs[0]["abs"], 0)
38+
39+
def test_onnx_plug_export(self):
40+
def _test_customsub(x, y):
41+
return x - y
42+
43+
def _test_customsub_shape(x, y):
44+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
45+
46+
def make_function_proto():
47+
return oh.make_function(
48+
"onnx_plug",
49+
"_test_customsub",
50+
["x", "y"],
51+
["z"],
52+
[oh.make_node("Sub", ["x", "y"], ["z"])],
53+
opset_imports=[oh.make_opsetid("", 22)],
54+
)
55+
56+
class Model(torch.nn.Module):
57+
def forward(self, x):
58+
y = x.sum(axis=1, keepdim=True)
59+
d = torch.ops.onnx_plug._test_customsub(x, y)
60+
return torch.abs(d)
61+
62+
replacements = [
63+
EagerDirectReplacementWithOnnx(
64+
_test_customsub, _test_customsub_shape, make_function_proto(), 2, 1
65+
)
66+
]
67+
68+
x = torch.randn((3, 4), dtype=torch.float32)
69+
model = Model()
70+
expected = model(x)
71+
ds = ({0: "d1", 1: "d2"},)
72+
ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds))
73+
self.assertIn("torch.ops.onnx_plug._test_customsub.default", str(ep))
74+
got = ep.module()(x)
75+
self.assertEqualArray(expected, got)
76+
77+
with self.subTest(exporter="custom"):
78+
onx = to_onnx(
79+
model,
80+
(x,),
81+
dynamic_shapes=ds,
82+
exporter="custom",
83+
onnx_plugs=replacements,
84+
target_opset=22,
85+
)
86+
self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,))
87+
88+
with self.subTest(exporter="onnx-dynamo"):
89+
onx = to_onnx(
90+
model,
91+
(x,),
92+
dynamic_shapes=ds,
93+
exporter="onnx-dynamo",
94+
onnx_plugs=replacements,
95+
target_opset=22,
96+
)
97+
self.assert_onnx_disc(
98+
"test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,)
99+
)
100+
101+
102+
if __name__ == "__main__":
103+
unittest.main(verbosity=2)

onnx_diagnostic/export/api.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
22
import torch
3+
from .onnx_plug import EagerDirectReplacementWithOnnx
34

45

56
def to_onnx(
@@ -18,6 +19,7 @@ def to_onnx(
1819
save_ep: Optional[str] = None,
1920
optimize: bool = True,
2021
use_control_flow_dispatcher: bool = False,
22+
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
2123
) -> Any:
2224
"""
2325
Common API for exporters. By default, the models are optimized to use the
@@ -41,6 +43,7 @@ def to_onnx(
4143
:param optimize: optimizes the model
4244
:param use_control_flow_dispatcher: use the dispatcher created to supported
4345
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
46+
:param onnx_plugs: the code was modified to replace some parts with onnx translation
4447
:return: the output of the selected exporter, usually a structure including
4548
an onnx model
4649
@@ -55,6 +58,10 @@ def to_onnx(
5558
exporter=exporter,
5659
filename=filename,
5760
)
61+
62+
Some examples using control flows are available in
63+
:func:`onnx_diagnostic.export.control_flow.loop_for` or
64+
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
5865
"""
5966
if exporter == "custom":
6067
from experimental_experiment.torch_interpreter import (
@@ -63,16 +70,38 @@ def to_onnx(
6370
)
6471
from experimental_experiment.xbuilder import OptimizationOptions
6572

66-
if use_control_flow_dispatcher:
67-
from .control_flow import create_global_dispatcher
68-
69-
dispatcher = create_global_dispatcher()
70-
7173
options = None
7274
if exporter_kwargs is not None:
7375
options = exporter_kwargs.pop("options", None)
7476
if options is None:
7577
options = OptimizationOptions(patterns="default+onnxruntime")
78+
if onnx_plugs or use_control_flow_dispatcher:
79+
from experimental_experiment.torch_interpreter import Dispatcher
80+
81+
if use_control_flow_dispatcher:
82+
from .control_flow import create_global_dispatcher
83+
84+
control_flow_dispatcher = create_global_dispatcher()
85+
else:
86+
control_flow_dispatcher = None
87+
88+
class MainDispatcher(Dispatcher):
89+
def __init__(self):
90+
super().__init__({})
91+
92+
main_dispatcher = MainDispatcher()
93+
if control_flow_dispatcher:
94+
main_dispatcher.registered_functions.update(
95+
control_flow_dispatcher.registered_functions
96+
)
97+
if onnx_plugs:
98+
for plug in onnx_plugs:
99+
main_dispatcher.registered_functions[plug.target_name] = (
100+
plug.custom_converter()
101+
)
102+
103+
else:
104+
main_dispatcher = None
76105

77106
return _to_onnx(
78107
mod,
@@ -89,7 +118,7 @@ def to_onnx(
89118
export_options=ExportOptions(save_ep=save_ep),
90119
options=options,
91120
**(exporter_kwargs or {}),
92-
dispatcher=dispatcher if use_control_flow_dispatcher else None,
121+
dispatcher=main_dispatcher,
93122
)
94123

95124
if exporter in ("dynamo", "onnx-dynamo"):
@@ -99,6 +128,10 @@ def to_onnx(
99128
assert (
100129
not output_dynamic_shapes
101130
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
131+
custom_translation_table = {}
132+
if onnx_plugs:
133+
for plug in onnx_plugs:
134+
custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
102135
epo = torch.onnx.export(
103136
mod,
104137
args=args or tuple(),
@@ -111,6 +144,7 @@ def to_onnx(
111144
verbose=verbose,
112145
dump_exported_program=bool(save_ep),
113146
artifacts_dir=os.path.dirname(filename) if filename else ".",
147+
custom_translation_table=custom_translation_table,
114148
**(exporter_kwargs or {}),
115149
)
116150
if optimize:

onnx_diagnostic/export/control_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def register(self, aten_name: str, converter: Callable):
3636

3737
@contextlib.contextmanager
3838
def enable_code_export_control_flow():
39-
"""Enables the code means to be exported."""
39+
"""Enables the code meant to be exported."""
4040
global _TEST_EXPORT
4141
old = _TEST_EXPORT
4242
_TEST_EXPORT = True

0 commit comments

Comments
 (0)