Skip to content

Commit eccdc07

Browse files
authored
Implements higher ops loop_for for ONNX (experimental) (#297)
* first draft * fix * loop * fix * fix a few things * fix * mypy * fix * dis * fix a few things * spell
1 parent 0d3bd28 commit eccdc07

File tree

8 files changed

+852
-2
lines changed

8 files changed

+852
-2
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.2
55
+++++
66

7+
* :pr:`297`: experiment around a higher ops ``loop_for``
78
* :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
89

910
0.8.1

_doc/api/export/control_flow.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.export.control_flow
3+
===================================
4+
5+
.. automodule:: onnx_diagnostic.export.control_flow
6+
:members:

_doc/api/export/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ onnx_diagnostic.export
66
:caption: modules
77

88
api
9+
control_flow
910
dynamic_shapes
1011
shape_helper
1112
validate
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import unittest
2+
from typing import Tuple
3+
import torch
4+
from onnxscript import script, FLOAT, INT64
5+
from onnxscript import opset18 as op
6+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
7+
from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
8+
from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
9+
from onnx_diagnostic.export.api import to_onnx
10+
11+
12+
class TestControlFlow(ExtTestCase):
13+
@unittest.skip("not working")
14+
def test_loop_one_research(self):
15+
class Model(torch.nn.Module):
16+
def forward(self, n_iter, x):
17+
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
18+
return (x[: i.item() + 1].unsqueeze(1),)
19+
20+
return loop_for_r(n_iter, body, (x,))[0]
21+
22+
model = Model()
23+
n_iter = torch.tensor(4, dtype=torch.int64)
24+
x = torch.arange(10, dtype=torch.float32)
25+
expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
26+
got = model(n_iter, x)
27+
self.assertEqualArray(expected, got)
28+
29+
with enable_code_export_control_flow():
30+
got = model(n_iter, x)
31+
self.assertEqualArray(expected, got)
32+
33+
ep = torch.export.export(
34+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
35+
)
36+
print(ep)
37+
38+
def test_onnxscript_loop(self):
39+
@script()
40+
def concatenation(N: INT64[1], x: FLOAT[None]) -> FLOAT[None, 1]:
41+
copy = op.Identity(x)
42+
res = op.SequenceEmpty()
43+
for i in range(N):
44+
res = op.SequenceInsert(res, op.Unsqueeze(copy[:i], [1]))
45+
return op.ConcatFromSequence(res, axis=1)
46+
47+
onx = concatenation.to_model_proto()
48+
self.dump_onnx("test_onnxscript_loop.onnx", onx)
49+
50+
@requires_torch("2.9.99")
51+
def test_loop_one_custom(self):
52+
class Model(torch.nn.Module):
53+
def forward(self, n_iter, x):
54+
def body(i, x):
55+
return x[: i.item() + 1].unsqueeze(1)
56+
57+
return loop_for(n_iter, body, (x,))
58+
59+
model = Model()
60+
n_iter = torch.tensor(4, dtype=torch.int64)
61+
x = torch.arange(10, dtype=torch.float32)
62+
expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
63+
got = model(n_iter, x)
64+
self.assertEqualArray(expected, got)
65+
66+
ep = torch.export.export(
67+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
68+
)
69+
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
70+
71+
onx = to_onnx(
72+
model,
73+
(n_iter, x),
74+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
75+
exporter="custom",
76+
use_control_flow_dispatcher=True,
77+
).model_proto
78+
self.dump_onnx("test_loop_one_custom.onnx", onx)
79+
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
80+
81+
@requires_torch("2.9.99")
82+
def test_loop_one_custom_different_opset(self):
83+
class Model(torch.nn.Module):
84+
def forward(self, n_iter, x):
85+
def body(i, x):
86+
return x[: i.item() + 1].unsqueeze(1)
87+
88+
return loop_for(n_iter, body, (x,))
89+
90+
model = Model()
91+
n_iter = torch.tensor(4, dtype=torch.int64)
92+
x = torch.arange(10, dtype=torch.float32)
93+
expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
94+
got = model(n_iter, x)
95+
self.assertEqualArray(expected, got)
96+
97+
ep = torch.export.export(
98+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
99+
)
100+
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
101+
102+
onx = to_onnx(
103+
model,
104+
(n_iter, x),
105+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
106+
exporter="custom",
107+
use_control_flow_dispatcher=True,
108+
target_opset=22,
109+
).model_proto
110+
opsets = {d.domain: d.version for d in onx.opset_import}
111+
self.assertEqual(opsets[""], 22)
112+
self.dump_onnx("test_loop_one_custom.onnx", onx)
113+
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
114+
115+
@requires_torch("2.9.99")
116+
def test_loop_two_custom(self):
117+
class Model(torch.nn.Module):
118+
def forward(self, n_iter, x):
119+
def body(i, x):
120+
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
121+
122+
res = loop_for(n_iter, body, (x,))
123+
return res[0] + res[1]
124+
125+
model = Model()
126+
n_iter = torch.tensor(4, dtype=torch.int64)
127+
x = torch.arange(10, dtype=torch.float32)
128+
expected = torch.tensor([1, 1, 3, 1, 3, 5, 1, 3, 5, 7], dtype=x.dtype).unsqueeze(1)
129+
got = model(n_iter, x)
130+
self.assertEqualArray(expected, got)
131+
132+
ep = torch.export.export(
133+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
134+
)
135+
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
136+
137+
onx = to_onnx(
138+
model,
139+
(n_iter, x),
140+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
141+
exporter="custom",
142+
use_control_flow_dispatcher=True,
143+
).model_proto
144+
self.dump_onnx("test_loop_one_custom.onnx", onx)
145+
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
146+
147+
@requires_torch("2.9.99")
148+
def test_loop_two_custom_reduction_dim(self):
149+
class Model(torch.nn.Module):
150+
def forward(self, n_iter, x):
151+
def body(i, x):
152+
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
153+
154+
res = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
155+
return res[0] + res[1].T
156+
157+
model = Model()
158+
n_iter = torch.tensor(4, dtype=torch.int64)
159+
x = torch.arange(10, dtype=torch.float32)
160+
expected = torch.tensor([1, 1, 3, 1, 3, 5, 1, 3, 5, 7], dtype=x.dtype).unsqueeze(1)
161+
got = model(n_iter, x)
162+
self.assertEqualArray(expected, got)
163+
164+
ep = torch.export.export(
165+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
166+
)
167+
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
168+
169+
onx = to_onnx(
170+
model,
171+
(n_iter, x),
172+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
173+
exporter="custom",
174+
use_control_flow_dispatcher=True,
175+
).model_proto
176+
self.dump_onnx("test_loop_one_custom.onnx", onx)
177+
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
178+
179+
180+
if __name__ == "__main__":
181+
unittest.main(verbosity=2)

_unittests/ut_export/test_jit.py renamed to _unittests/ut_export/test_experiment_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def select_when_exporting(f, f_scan):
5151
return f_scan if is_torchdynamo_exporting() else f
5252

5353

54-
class TestJit(ExtTestCase):
54+
class TestExperimentJit(ExtTestCase):
5555
def test_dummy_loop(self):
5656
x = torch.randn((5, 6))
5757
y = torch.arange(5, dtype=torch.int64) + 1

onnx_diagnostic/export/api.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def to_onnx(
1616
exporter: str = "onnx-dynamo",
1717
exporter_kwargs: Optional[Dict[str, Any]] = None,
1818
save_ep: Optional[str] = None,
19+
use_control_flow_dispatcher: bool = False,
1920
) -> Any:
2021
"""
2122
Common API for exporters. By default, the models are optimized to use the
@@ -36,6 +37,8 @@ def to_onnx(
3637
:param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
3738
:param exporter_kwargs: additional parameters sent to the exporter
3839
:param save_ep: saves the exported program
40+
:param use_control_flow_dispatcher: use the dispatcher created to supported
41+
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
3942
:return: the output of the selected exporter, usually a structure including
4043
an onnx model
4144
@@ -58,6 +61,17 @@ def to_onnx(
5861
)
5962
from experimental_experiment.xbuilder import OptimizationOptions
6063

64+
if use_control_flow_dispatcher:
65+
from .control_flow import create_global_dispatcher
66+
67+
dispatcher = create_global_dispatcher()
68+
69+
options = None
70+
if exporter_kwargs is not None:
71+
options = exporter_kwargs.pop("options", None)
72+
if options is None:
73+
options = OptimizationOptions(patterns="default+onnxruntime")
74+
6175
return _to_onnx(
6276
mod,
6377
args=args,
@@ -71,8 +85,9 @@ def to_onnx(
7185
large_model=True,
7286
output_dynamic_shapes=output_dynamic_shapes,
7387
export_options=ExportOptions(save_ep=save_ep),
74-
options=OptimizationOptions(patterns="default+onnxruntime"),
88+
options=options,
7589
**(exporter_kwargs or {}),
90+
dispatcher=dispatcher if use_control_flow_dispatcher else None,
7691
)
7792
if exporter in ("dynamo", "onnx-dynamo"):
7893
import onnxscript.rewriter.ort_fusions as ort_fusions

0 commit comments

Comments
 (0)