Skip to content

Commit 4a66796

Browse files
committed
fix test
1 parent e786ca0 commit 4a66796

File tree

5 files changed

+128
-10
lines changed

5 files changed

+128
-10
lines changed

_unittests/ut_export/test_control_flow.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ def body(i, x):
6666
ep = torch.export.export(
6767
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
6868
)
69-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
69+
self.assertIn(
70+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
71+
str(ep),
72+
)
7073

7174
onx = to_onnx(
7275
model,
@@ -97,7 +100,10 @@ def body(i, x):
97100
ep = torch.export.export(
98101
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
99102
)
100-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
103+
self.assertIn(
104+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
105+
str(ep),
106+
)
101107

102108
onx = to_onnx(
103109
model,
@@ -132,7 +138,10 @@ def body(i, x):
132138
ep = torch.export.export(
133139
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
134140
)
135-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
141+
self.assertIn(
142+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
143+
str(ep),
144+
)
136145

137146
onx = to_onnx(
138147
model,
@@ -164,7 +173,10 @@ def body(i, x):
164173
ep = torch.export.export(
165174
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
166175
)
167-
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
176+
self.assertIn(
177+
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
178+
str(ep),
179+
)
168180

169181
onx = to_onnx(
170182
model,

_unittests/ut_export/test_onnx_plug.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
import onnx.helper as oh
33
import torch
4-
from onnx_diagnostic.ext_test_case import ExtTestCase
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch
55
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
66
from onnx_diagnostic.export.api import to_onnx
77

@@ -85,6 +85,8 @@ def forward(self, x):
8585
)
8686
self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,))
8787

88+
if not has_torch("2.9"):
89+
raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
8890
with self.subTest(exporter="onnx-dynamo"):
8991
onx = to_onnx(
9092
model,

onnx_diagnostic/export/api.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def to_onnx(
2020
optimize: bool = True,
2121
use_control_flow_dispatcher: bool = False,
2222
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
23+
inline: bool = True,
2324
) -> Any:
2425
"""
2526
Common API for exporters. By default, the models are optimized to use the
@@ -44,6 +45,7 @@ def to_onnx(
4445
:param use_control_flow_dispatcher: use the dispatcher created to supported
4546
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
4647
:param onnx_plugs: the code was modified to replace some parts with onnx translation
48+
:param inline: inline local functions
4749
:return: the output of the selected exporter, usually a structure including
4850
an onnx model
4951
@@ -63,6 +65,11 @@ def to_onnx(
6365
:func:`onnx_diagnostic.export.control_flow.loop_for` or
6466
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
6567
"""
68+
if exporter_kwargs and "inline" in exporter_kwargs:
69+
assert (
70+
inline == exporter_kwargs["inline"]
71+
), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
72+
exporter_kwargs.pop("inline")
6673
if exporter == "custom":
6774
from experimental_experiment.torch_interpreter import (
6875
to_onnx as _to_onnx,
@@ -90,7 +97,7 @@ def __init__(self):
9097
super().__init__({})
9198

9299
main_dispatcher = MainDispatcher()
93-
if control_flow_dispatcher:
100+
if use_control_flow_dispatcher:
94101
main_dispatcher.registered_functions.update(
95102
control_flow_dispatcher.registered_functions
96103
)
@@ -99,7 +106,6 @@ def __init__(self):
99106
main_dispatcher.registered_functions[plug.target_name] = (
100107
plug.custom_converter()
101108
)
102-
103109
else:
104110
main_dispatcher = None
105111

@@ -117,8 +123,9 @@ def __init__(self):
117123
output_dynamic_shapes=output_dynamic_shapes,
118124
export_options=ExportOptions(save_ep=save_ep),
119125
options=options,
120-
**(exporter_kwargs or {}),
126+
inline=inline,
121127
dispatcher=main_dispatcher,
128+
**(exporter_kwargs or {}),
122129
)
123130

124131
if exporter in ("dynamo", "onnx-dynamo"):
@@ -147,7 +154,21 @@ def __init__(self):
147154
custom_translation_table=custom_translation_table,
148155
**(exporter_kwargs or {}),
149156
)
150-
if optimize:
157+
if not inline and optimize:
158+
ort_fusions.optimize_for_ort(epo.model)
159+
160+
if onnx_plugs:
161+
import onnx_ir as ir
162+
import onnx_ir.passes.common as common_passes
163+
164+
irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
165+
for func in irfunctions:
166+
epo.model.functions[func.identifier()] = func
167+
if inline:
168+
common_passes.InlinePass()(epo.model)
169+
common_passes.RemoveUnusedOpsetsPass()(epo.model)
170+
171+
if inline and optimize:
151172
ort_fusions.optimize_for_ort(epo.model)
152173
if filename:
153174
epo.save(filename, external_data=True)

onnx_diagnostic/export/control_flow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def make_custom_loop_for(
134134
assert body_outputs is not None, "body_outputs cannot be None"
135135
srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
136136
sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
137-
name = f"loop_for_{body_fn.__name__}_{id(body_fn)}_{srank}_{sred}"
137+
full_name = body_fn.__qualname__.replace("<locals>", "L").replace(".", "_")
138+
name = f"loop_for_{full_name}_{srank}_{sred}"
138139
if name in _REGISTERED_SCHEMA:
139140
return name, _REGISTERED_SCHEMA[name][0]
140141
sig = inspect.signature(body_fn)

onnx_diagnostic/export/onnx_plug.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,88 @@ class EagerDirectReplacementWithOnnx:
4949
:param n_outputs: same for the number of outputs,
5050
only tensors must be counted
5151
:param name: the name of the custom op, the function name if not specified
52+
53+
Here is an example:
54+
55+
.. runpython::
56+
:showcode:
57+
58+
import onnx.helper as oh
59+
import torch
60+
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
61+
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
62+
from onnx_diagnostic.export.api import to_onnx
63+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
64+
65+
66+
def demo_customsub(x, y):
67+
return x - y
68+
69+
70+
def demo_customsub_shape(x, y):
71+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
72+
73+
74+
def make_function_proto():
75+
return oh.make_function(
76+
"onnx_plug",
77+
"demo_customsub",
78+
["x", "y"],
79+
["z"],
80+
[oh.make_node("Sub", ["x", "y"], ["z"])],
81+
opset_imports=[oh.make_opsetid("", 22)],
82+
)
83+
84+
85+
class Model(torch.nn.Module):
86+
def forward(self, x):
87+
y = x.sum(axis=1, keepdim=True)
88+
d = torch.ops.onnx_plug.demo_customsub(x, y)
89+
return torch.abs(d)
90+
91+
92+
replacements = [
93+
EagerDirectReplacementWithOnnx(
94+
demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
95+
)
96+
]
97+
98+
x = torch.randn((3, 4), dtype=torch.float32)
99+
model = Model()
100+
ds = ({0: "d1", 1: "d2"},)
101+
102+
# The exported program shows a custom op.
103+
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
104+
print("ep")
105+
106+
# As the exporter knows how the replace this custom op.
107+
# Let's export.
108+
109+
onx = to_onnx(
110+
model,
111+
(x,),
112+
dynamic_shapes=ds,
113+
exporter="custom",
114+
onnx_plugs=replacements,
115+
target_opset=22,
116+
inline=False,
117+
).model_proto
118+
119+
print(pretty_onnx(onx))
120+
121+
# And with :func:`torch.onnx.export`:
122+
123+
onx = to_onnx(
124+
model,
125+
(x,),
126+
dynamic_shapes=ds,
127+
exporter="onnx-dynamo",
128+
onnx_plugs=replacements,
129+
target_opset=22,
130+
inline=False,
131+
).model_proto
132+
133+
print(pretty_onnx(onx))
52134
"""
53135

54136
def __init__(

0 commit comments

Comments
 (0)