Commit 4c129eb

fix a few things
1 parent add3086 commit 4c129eb

3 files changed: +87 -8 lines changed
_unittests/ut_export/test_control_flow.py

Lines changed: 34 additions & 0 deletions
@@ -78,6 +78,40 @@ def body(i, x):
         self.dump_onnx("test_loop_one_custom.onnx", onx)
         self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
 
+    @requires_torch("2.9.99")
+    def test_loop_one_custom_different_opset(self):
+        class Model(torch.nn.Module):
+            def forward(self, n_iter, x):
+                def body(i, x):
+                    return x[: i.item() + 1].unsqueeze(1)
+
+                return loop_for(n_iter, body, (x,))
+
+        model = Model()
+        n_iter = torch.tensor(4, dtype=torch.int64)
+        x = torch.arange(10, dtype=torch.float32)
+        expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
+        got = model(n_iter, x)
+        self.assertEqualArray(expected, got)
+
+        ep = torch.export.export(
+            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+        )
+        self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
+
+        onx = to_onnx(
+            model,
+            (n_iter, x),
+            dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
+            exporter="custom",
+            use_control_flow_dispatcher=True,
+            target_opset=22,
+        ).model_proto
+        opsets = {d.domain: d.version for d in onx.opset_import}
+        self.assertEqual(opsets[""], 22)
+        self.dump_onnx("test_loop_one_custom.onnx", onx)
+        self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
+
     @requires_torch("2.9.99")
     def test_loop_two_custom(self):
         class Model(torch.nn.Module):
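
The new test checks that a target_opset explicitly requested by the caller now reaches the exported model when the graph uses loop_for. A minimal standalone sketch of the same usage (the import paths for to_onnx and loop_for are assumptions based on the files touched by this commit):

    # Sketch, not part of the commit: verify the requested opset lands in opset_import.
    import torch
    from onnx_diagnostic.export.api import to_onnx                # assumed import path
    from onnx_diagnostic.export.control_flow import loop_for      # assumed import path

    class Model(torch.nn.Module):
        def forward(self, n_iter, x):
            def body(i, x):
                return x[: i.item() + 1].unsqueeze(1)
            return loop_for(n_iter, body, (x,))

    onx = to_onnx(
        Model(),
        (torch.tensor(4, dtype=torch.int64), torch.arange(10, dtype=torch.float32)),
        exporter="custom",
        use_control_flow_dispatcher=True,
        target_opset=22,
    ).model_proto
    assert {d.domain: d.version for d in onx.opset_import}[""] == 22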

onnx_diagnostic/export/api.py

Lines changed: 7 additions & 1 deletion
@@ -66,6 +66,12 @@ def to_onnx(
 
     dispatcher = create_global_dispatcher()
 
+    options = None
+    if exporter_kwargs is not None:
+        options = exporter_kwargs.pop("options", None)
+    if options is None:
+        options = OptimizationOptions(patterns="default+onnxruntime")
+
     return _to_onnx(
         mod,
         args=args,
@@ -79,7 +85,7 @@ def to_onnx(
         large_model=True,
         output_dynamic_shapes=output_dynamic_shapes,
         export_options=ExportOptions(save_ep=save_ep),
-        options=OptimizationOptions(patterns="default+onnxruntime"),
+        options=options,
         **(exporter_kwargs or {}),
         dispatcher=dispatcher if use_control_flow_dispatcher else None,
     )
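
With this change, to_onnx no longer hard-codes the optimization options: an "options" entry supplied through exporter_kwargs is popped and used, and the previous "default+onnxruntime" patterns remain the fallback when nothing is given. A hedged sketch of how a caller could override them (OptimizationOptions is assumed to come from the underlying custom exporter package; the import path below is an assumption):

    # Sketch, not part of the commit: supplying custom optimization options.
    from experimental_experiment.xbuilder import OptimizationOptions  # assumed import path

    onx = to_onnx(
        Model(),                                   # the Model class from the test above
        (torch.tensor(4, dtype=torch.int64), torch.arange(10, dtype=torch.float32)),
        exporter="custom",
        use_control_flow_dispatcher=True,
        exporter_kwargs={"options": OptimizationOptions(patterns="default")},
    ).model_proto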

onnx_diagnostic/export/control_flow.py

Lines changed: 46 additions & 7 deletions
@@ -148,14 +148,31 @@ def make_custom_loop_for(
     custom_def._abstract_fn = lambda *_args, _o=body_outputs: (
         tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
     )
-    onx = convert_into_onnx(body_gm, args)
+
+    def _make_onx(
+        body_gm=body_gm, args=args, target_opset=None, verbose=0, exporter_kwargs=None
+    ):
+        return convert_into_onnx(
+            body_gm,
+            args,
+            exporter_kwargs=exporter_kwargs,
+            target_opset=target_opset,
+            verbose=verbose,
+        )
+
     to_register = (
         custom_def,
-        onx,
+        _make_onx,
         (
-            lambda g, sts, outputs, *args, body=onx, reduction_dim=reduction_dim, name=name: (
+            lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
                 convert_custom_loop_into_onnx(
-                    g, sts, outputs, *args, body=body, reduction_dim=reduction_dim, name=name
+                    g,
+                    sts,
+                    outputs,
+                    *args,
+                    body_callable=bc,
+                    reduction_dim=rd,
+                    name=name,
                 )
             )
         ),
@@ -173,7 +190,7 @@ def convert_custom_loop_into_onnx(
     sts: Dict[str, Any],
     outputs: List[str],
     *args: str,
-    body: onnx.GraphProto,
+    body_callable: Callable[..., onnx.ModelProto],
    reduction_dim: Optional[Sequence[int]] = None,
     name: str = "loop_for",
 ) -> Union[str, List[str]]:
@@ -190,6 +207,14 @@ def convert_custom_loop_into_onnx(
     :param name: to give the onnx nodes a name
     :return: output names
     """
+    assert body_callable is not None, "body_callable cannot be None"
+    # This should be part of a public API.
+    body = body_callable(
+        target_opset=g.main_opset,
+        verbose=g.verbose,
+        exporter_kwargs={"options": g.optimization_options},
+    )
+
     graph = body.graph if isinstance(body, onnx.ModelProto) else body
     assert isinstance(
         graph, onnx.GraphProto
@@ -261,19 +286,33 @@
 
 
 def convert_into_onnx(
-    body_gm: torch.fx.GraphModule, args: Sequence[torch.Tensor]
+    body_gm: torch.fx.GraphModule,
+    args: Sequence[torch.Tensor],
+    target_opset: Optional[int] = None,
+    verbose: int = 0,
+    exporter_kwargs: Optional[Dict[str, Any]] = None,
 ) -> onnx.ModelProto:
     """
     Converts a torch.fx.GraphModule into ONNX.
     It returns a ModelProto.
 
     :param body_gm: a torch.fx.GraphModule
     :param args: arguments known at export time
+    :param target_opset: targetted opset
+    :param verbose: verbosity level
+    :param exporter_kwargs: additional exporter arguments
     :return: a ModelProto
     """
     # This does not work with onnx-dynamo.
     # opset still needs to be defined
-    container = to_onnx(body_gm, args, exporter="custom")
+    container = to_onnx(
+        body_gm,
+        args,
+        exporter="custom",
+        exporter_kwargs=exporter_kwargs,
+        target_opset=target_opset,
+        verbose=verbose,
+    )
     return container.model_proto
 
 
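Taken together, the control_flow.py changes defer the conversion of the loop body. make_custom_loop_for used to call convert_into_onnx eagerly at registration time, with default settings, and the resulting ModelProto was baked into the registered converter. It now registers the _make_onx closure instead, and convert_custom_loop_into_onnx invokes it only while the outer graph is being built, forwarding g.main_opset, g.verbose and g.optimization_options from the enclosing graph builder (the attribute names are those visible in the diff above). That is why target_opset=22 in the new test also governs the Loop body rather than whatever opset the eager conversion would have picked.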