Commit 9360a96

other research around loops (#328)
* other research
* fix
1 parent f9d798e commit 9360a96

File tree

3 files changed: +249 -3 lines changed

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.export.control_flow import loop_for


class TestControlFlow(ExtTestCase):
    @requires_torch("2.9.99")
    def test_loop_for(self):
        class Model(torch.nn.Module):
            def forward(self, n_iter, x):
                def body(i, x):
                    return x[: i.item() + 1].unsqueeze(1)

                return loop_for(n_iter, body, (x,))

        model = Model()
        n_iter = torch.tensor(4, dtype=torch.int64)
        x = torch.arange(10, dtype=torch.float32)
        expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
        got = model(n_iter, x)
        self.assertEqualArray(expected, got)

        ep = torch.export.export(
            model, (n_iter, x), dynamic_shapes=({}, {0: torch.export.Dim.DYNAMIC})
        )
        names = set(m for m, _ in ep.module().named_modules())
        self.assertIn("", names)


if __name__ == "__main__":
    unittest.main(verbosity=2)
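
[Reviewer note, not part of the commit] The `expected` tensor is the concatenation of the prefixes `x[:1]`, `x[:2]`, `x[:3]`, `x[:4]`, each turned into a column by `unsqueeze(1)`. A minimal eager sketch reproducing it:

import torch

x = torch.arange(10, dtype=torch.float32)
# Unroll the four iterations of body(i, x) == x[: i + 1].unsqueeze(1).
pieces = [x[: i + 1].unsqueeze(1) for i in range(4)]
expected = torch.cat(pieces, dim=0)
print(expected.squeeze(1))  # tensor([0., 0., 1., 0., 1., 2., 0., 1., 2., 3.])
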
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
import contextlib
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import torch
from torch._higher_order_ops.utils import (
    materialize_as_graph,
    check_input_alias_and_mutation_return_outputs,
    # _maybe_reenter_make_fx,
)

_TEST_EXPORT = False

@contextlib.contextmanager
def enable_code_export_control_flow():
    """Enables the code meant to be exported."""
    global _TEST_EXPORT
    old = _TEST_EXPORT
    _TEST_EXPORT = True
    try:
        yield
    finally:
        _TEST_EXPORT = old


def is_exporting() -> bool:
    """
    Returns :func:`torch.compiler.is_exporting` or
    :func:`torch.compiler.is_compiling`.
    Setting ``_TEST_EXPORT`` (see :func:`enable_code_export_control_flow`)
    forces it to return True.
    """
    return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()

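[Reviewer note, not part of the commit] `enable_code_export_control_flow` makes it possible to exercise the exporting branch of `loop_for` without invoking `torch.export.export`. A minimal sketch, assuming this file is `onnx_diagnostic/export/control_flow.py` (the module the test imports from):

from onnx_diagnostic.export.control_flow import (
    enable_code_export_control_flow,
    is_exporting,
)

assert not is_exporting()      # eager mode by default
with enable_code_export_control_flow():
    assert is_exporting()      # the export branch is now taken
assert not is_exporting()      # previous value restored on exit
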
def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
    """
    Python implementation of the loop.

    :param n_iter: number of iterations
    :param body_fn: function implementing the body
    :param reduction_dim: dimensions used to reduce the lists produced by the loop
    :param args: arguments to the loop body
    :return: results
    """
    res = []
    for i in torch.arange(n_iter, dtype=n_iter.dtype):
        r = body_fn(i, *args)
        if isinstance(r, tuple):
            assert not res or len(r) == len(res[-1]), (
                f"Unexpected number of results {len(r)} for function {body_fn}, "
                f"expected {len(res[-1])}"
            )
            res.append(r)
        else:
            assert isinstance(r, torch.Tensor), (
                f"Unexpected type {r} for function {body_fn}, "
                f"it must be a tuple or a Tensor."
            )
            assert not res or len(res[-1]) == 1, (
                f"Unexpected single tensor returned by function {body_fn}, "
                f"previous iterations returned {len(res[-1])} results"
            )
            res.append((r,))

    if not res:
        return torch.empty(tuple(), dtype=torch.float32, device=args[0].device)
    if len(res) == 1:
        final = res[0]
    else:
        n_res = len(res[0])
        final = [
            torch.cat(
                [r[i] for r in res],
                dim=(
                    0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
                ),
            )
            for i in range(n_res)
        ]
    return tuple(final) if len(final) > 1 else final[0]

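[Reviewer note, not part of the commit] The stacking semantics are easiest to see with a body returning two tensors and a per-output `reduction_dim`. A minimal sketch, same module-path assumption as above:

import torch
from onnx_diagnostic.export.control_flow import _loop_for_fn

def body(i, x):
    p = x[: i.item() + 1].unsqueeze(1)  # shape (i+1, 1)
    return p, p.T                       # second output is concatenated on dim 1

n_iter = torch.tensor(3, dtype=torch.int64)
x = torch.arange(5, dtype=torch.float32)
a, b = _loop_for_fn(n_iter, body, [0, 1], (x,))
print(a.shape)  # torch.Size([6, 1]): 1 + 2 + 3 rows stacked on dim 0
print(b.shape)  # torch.Size([1, 6]): the same pieces stacked on dim 1
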
def make_custom_loop_for(
    n_iter: torch.Tensor,
    body_fn: Callable,
    reduction_dim: Optional[Sequence[int]],
    args: Sequence[torch.Tensor],
    body_gm: Optional[torch.fx.GraphModule] = None,
    body_mutated_inputs: Optional[List[Any]] = None,
    body_outputs: Optional[List[Any]] = None,
) -> Tuple[str, torch.library.CustomOpDef]:
    """
    Defines a custom operator for a loop in order to avoid
    :func:`torch.export.export` digging into it.
    It registers the custom op and a custom conversion
    to ONNX.

    :param n_iter: number of iterations, defined by a tensor with no dimension
    :param body_fn: the loop body defined as a function
    :param reduction_dim: dimensions used to concatenate the results
    :param args: list of tensors, inputs to the body
    :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
    :param body_mutated_inputs: mutated inputs of *body_gm*
    :param body_outputs: outputs of *body_gm*
    :return: a name and the custom op definition, the name
        is used to cache the custom op
    """
    assert body_gm is not None, "body_gm cannot be None"
    assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
    assert body_outputs is not None, "body_outputs cannot be None"

    srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
    sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
    full_name = (
        body_fn.__qualname__.replace("<locals>", "L")
        .replace("<lambda>", "l")
        .replace(".", "_")
    )
    name = f"loop_for_onnx_{full_name}_{srank}_{sred}"

    schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor"
    if len(body_outputs) > 1:
        schema += "[]"
    custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn)
    custom_def.register_kernel("cpu")(body_fn)

    custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: (
        tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
    )
    return name, custom_def

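[Reviewer note, not part of the commit] The cache name combines the body's qualified name with the output shapes (`srank`) and reduction dimensions (`sred`). A small sketch of the qualified-name mangling alone, for a body defined inside a function:

def make():
    def body(i, x):
        return x[: i.item() + 1]
    return body

body_fn = make()
full_name = (
    body_fn.__qualname__.replace("<locals>", "L")
    .replace("<lambda>", "l")
    .replace(".", "_")
)
print(full_name)  # make_L_body
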
def loop_for(
    n_iter: Union[torch.SymInt, torch.Tensor],
    body_fn: Callable[..., Tuple[torch.Tensor]],
    args: Sequence[torch.Tensor],
    reduction_dim: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, ...]:
    """
    Higher-order operator used to easily export a loop to ONNX.
    It does not fully work with :func:`torch.export.export`;
    it inserts a custom op which is replaced with a loop operator afterwards.
    Every iteration produces tensors, all of them are gathered
    into lists, and all these lists are concatenated into tensors.

    :param n_iter: number of iterations, it can be fixed or
        variable, in which case it should be a tensor with no dimension
    :param body_fn: loop body, takes only tensors and returns
        only tensors, the first argument is the iteration number
        in a tensor with no dimension, all the others
        are not changed during the loop
    :param args: the tensors available at every iteration
    :param reduction_dim: the loop aggregates the results into lists,
        one for each output, and each of them is concatenated into one
        tensor along one dimension, by default the first
        dimension, but it can be defined otherwise
    """
    assert args, "The function should have at least one arg."
    assert (
        isinstance(n_iter, torch.Tensor)
        and n_iter.dtype == torch.int64
        and len(n_iter.shape) == 0
    ), f"Only a tensor holding one int64 is allowed for n_iter but it is equal to {n_iter}."
    if is_exporting():
        from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

        # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer
        root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root
        # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph

        body_gm: torch.fx.GraphModule = materialize_as_graph(
            body_fn, (torch.tensor(0, dtype=torch.int64), *args)
        )
        (
            _1,
            _2,
            _3,
            body_mutated_inputs,
            body_outputs,
        ) = check_input_alias_and_mutation_return_outputs(body_gm)
        name, _custom_ops = make_custom_loop_for(
            n_iter,
            body_fn,
            reduction_dim,
            args,
            body_gm=body_gm,
            body_mutated_inputs=body_mutated_inputs,
            body_outputs=body_outputs,
        )
        root.register_module(name, body_gm)
        # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args)
        return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args)

    return _loop_for_fn(n_iter, body_fn, reduction_dim, args)

"""
proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

args = (cond_graph, body_graph, carried_inputs, additional_inputs)

proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

out_proxy = proxy_mode.tracer.create_proxy(
    "call_function", op, proxy_args, {}, name=op._name
)

out = op(
    cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
)
return track_tensor_tree(
    out, out_proxy, constant=None, tracer=proxy_mode.tracer
)
"""

onnx_diagnostic/export/control_flow_onnx.py

Lines changed: 3 additions & 3 deletions
@@ -364,7 +364,7 @@ def loop_for_onnx(
 import torch
 import onnxruntime
 from onnx_diagnostic.export.api import to_onnx
-from onnx_diagnostic.export.control_flow import loop_for_onnx
+from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


 class Model(torch.nn.Module):
@@ -410,7 +410,7 @@ def body(i, x):
 import torch
 import onnxruntime
 from onnx_diagnostic.export.api import to_onnx
-from onnx_diagnostic.export.control_flow import loop_for_onnx
+from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


 class Model(torch.nn.Module):
@@ -457,7 +457,7 @@ def body(i, x):
 import torch
 import onnxruntime
 from onnx_diagnostic.export.api import to_onnx
-from onnx_diagnostic.export.control_flow import loop_for_onnx
+from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


 class Model(torch.nn.Module):
