Skip to content

Commit c329f21

Browse files
committed
fix a few things
1 parent ffcd801 commit c329f21

File tree

3 files changed

+234
-20
lines changed

3 files changed

+234
-20
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.2
55
+++++
66

7+
* :pr:`297`: experiment around a higher ops ``loop_for``
78
* :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
89

910
0.8.1

_unittests/ut_export/test_control_flow.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,68 @@ def body(i, x):
7777
self.dump_onnx("test_loop_one_custom.onnx", onx)
7878
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
7979

80+
def test_loop_two_custom(self):
81+
class Model(torch.nn.Module):
82+
def forward(self, n_iter, x):
83+
def body(i, x):
84+
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
85+
86+
res = loop_for(n_iter, body, (x,))
87+
return res[0] + res[1]
88+
89+
model = Model()
90+
n_iter = torch.tensor(4, dtype=torch.int64)
91+
x = torch.arange(10, dtype=torch.float32)
92+
expected = torch.tensor([1, 1, 3, 1, 3, 5, 1, 3, 5, 7], dtype=x.dtype).unsqueeze(1)
93+
got = model(n_iter, x)
94+
self.assertEqualArray(expected, got)
95+
96+
ep = torch.export.export(
97+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
98+
)
99+
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
100+
101+
onx = to_onnx(
102+
model,
103+
(n_iter, x),
104+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
105+
exporter="custom",
106+
use_control_flow_dispatcher=True,
107+
).model_proto
108+
self.dump_onnx("test_loop_one_custom.onnx", onx)
109+
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
110+
111+
def test_loop_two_custom_reduction_dim(self):
112+
class Model(torch.nn.Module):
113+
def forward(self, n_iter, x):
114+
def body(i, x):
115+
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
116+
117+
res = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
118+
return res[0] + res[1].T
119+
120+
model = Model()
121+
n_iter = torch.tensor(4, dtype=torch.int64)
122+
x = torch.arange(10, dtype=torch.float32)
123+
expected = torch.tensor([1, 1, 3, 1, 3, 5, 1, 3, 5, 7], dtype=x.dtype).unsqueeze(1)
124+
got = model(n_iter, x)
125+
self.assertEqualArray(expected, got)
126+
127+
ep = torch.export.export(
128+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
129+
)
130+
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
131+
132+
onx = to_onnx(
133+
model,
134+
(n_iter, x),
135+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
136+
exporter="custom",
137+
use_control_flow_dispatcher=True,
138+
).model_proto
139+
self.dump_onnx("test_loop_one_custom.onnx", onx)
140+
self.assert_onnx_disc("test_loop_one_custom", onx, model, (n_iter, x))
141+
80142

81143
if __name__ == "__main__":
82144
unittest.main(verbosity=2)

onnx_diagnostic/export/control_flow.py

Lines changed: 171 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import torch
77
from torch._higher_order_ops.utils import materialize_as_graph
88
from torch._higher_order_ops.utils import check_input_alias_and_mutation_return_outputs
9-
from ..helpers.onnx_helper import pretty_onnx
109
from .api import to_onnx
1110

1211
_TEST_EXPORT = False
@@ -47,15 +46,25 @@ def enable_code_export_control_flow():
4746
_TEST_EXPORT = old
4847

4948

50-
def is_exporting():
49+
def is_exporting() -> bool:
5150
"""
5251
Returns :func:`torch.compiler.is_exporting` or
53-
:func:`torch.compiler.is_compiling`
52+
:func:`torch.compiler.is_compiling`.
53+
Changes ``_TEST_EXPORT`` to make it trigger.
5454
"""
5555
return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
5656

5757

5858
def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
59+
"""
60+
Python implementation of the loop.
61+
62+
:param n_iter: number of iterations
63+
:param body_fn: function implementing the body
64+
:param reduction_dim: dimension used to reduce the list produced by the loop
65+
:param args: arguments to the loop body
66+
:return: results
67+
"""
5968
res = []
6069
for i in torch.arange(n_iter, dtype=n_iter.dtype):
6170
r = body_fn(i, *args)
@@ -95,14 +104,30 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
95104

96105

97106
def make_custom_loop_for(
98-
n_iter,
99-
body_fn,
100-
reduction_dim,
101-
args,
102-
body_gm=None,
103-
body_mutated_inputs=None,
104-
body_outputs=None,
105-
):
107+
n_iter: torch.Tensor,
108+
body_fn: Callable,
109+
reduction_dim: Optional[List[int]],
110+
args: List[torch.Tensor],
111+
body_gm: Optional[torch.fx.GraphModule] = None,
112+
body_mutated_inputs: Optional[List[Any]] = None,
113+
body_outputs: Optional[List[Any]] = None,
114+
) -> Tuple[str, torch.library.CustomOpDef]:
115+
"""
116+
Defines a custom operator for a loop in order to avoid
117+
:func:`torch.export.export` digging into it.
118+
It registers the custom op and a custom conversion
119+
to ONNX.
120+
121+
:param n_iter: number of iterations defined by a tensor of no dimension
122+
:param body_fn: the loop body defined as a function
123+
:param reduction_dim: dimension used to concatenate the results
124+
:param args: list of tensors, input to the body
125+
:param body_gm: torch.fx.GraphModule equivalent to *body_fn*
126+
:param body_mutated_inputs: inputs to *body_gm*
127+
:param body_outputs: outputs to *body_gm*
128+
:return: a name and the custom op definition, the name
129+
is used to cache the custom op
130+
"""
106131
global _DISPATCHER
107132
srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
108133
sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
@@ -120,7 +145,7 @@ def make_custom_loop_for(
120145
custom_def._abstract_fn = lambda *_args, _o=body_outputs: (
121146
tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
122147
)
123-
onx = convert_into_onnx(body_gm, args, body_mutated_inputs, body_outputs)
148+
onx = convert_into_onnx(body_gm, args)
124149
to_register = (
125150
custom_def,
126151
onx,
@@ -149,15 +174,26 @@ def convert_custom_loop_into_onnx(
149174
reduction_dim: Optional[Tuple[int, ...]] = None,
150175
name: str = "loop_for",
151176
) -> Union[str, Tuple[str, ...]]:
177+
"""
178+
Converts a custom op ``higher_ops::loop_for...`` into a sequence of nodes.
179+
180+
:param g: GraphBuilder
181+
:param sts: if not defined, torch does not know the output shapes
182+
:param outputs: output names
183+
:param args: input argument known at export time
184+
:param body: GraphProto, the loop body
185+
:param reduction_dim: the dimension to follow when aggregating the
186+
list of tensors after the loop ran
187+
:param name: to give the onnx nodes a name
188+
:return: output names
189+
"""
152190
graph = body.graph if isinstance(body, onnx.ModelProto) else body
153191
assert isinstance(
154192
graph, onnx.GraphProto
155193
), f"Unexpected type {type(body)} for body{g.get_debug_msg()}"
156-
assert len(outputs) == len(graph.output), (
157-
f"Length mismatch, expecting {len(outputs)} but got "
158-
f"{len(graph.output)}, \n--\n{pretty_onnx(body)}"
159-
f"{g.get_debug_msg()}"
160-
)
194+
assert len(outputs) == 1, f"Only one outputs is expected but outputs={outputs!r}"
195+
if len(graph.output) != 1:
196+
outputs = [f"{outputs[0]}#{i}" for i in range(len(graph.output))]
161197
input_names = [i.name for i in graph.input]
162198
inputs = [
163199
*graph.input[:1],
@@ -218,11 +254,20 @@ def convert_custom_loop_into_onnx(
218254
for i, o in enumerate(outputs):
219255
g.set_type(o, graph.output[i].type.tensor_type.elem_type)
220256
g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
221-
return tuple(outputs) if len(outputs) > 1 else outputs[0]
257+
return outputs if len(outputs) > 1 else outputs[0]
258+
222259

260+
def convert_into_onnx(
261+
body_gm: torch.fx.GraphModule, args: List[torch.Tensor]
262+
) -> onnx.ModelProto:
263+
"""
264+
Converts a torch.fx.GraphModule into ONNX.
265+
It returns a ModelProto.
223266
224-
def convert_into_onnx(body_gm, args, body_mutated_inputs, body_outputs):
225-
"""Converts into ONNX."""
267+
:param body_gm: a torch.fx.GraphModule
268+
:param args: arguments known at export time
269+
:return: a ModelProto
270+
"""
226271
# This does not work with onnx-dynamo.
227272
# opset still needs to be defined
228273
container = to_onnx(body_gm, args, exporter="custom")
@@ -239,6 +284,20 @@ def loop_for(
239284
High operators used to easily export a loop in ONNX.
240285
Does not fully work with :func:`torch.export.export`,
241286
it replaces a custom op with a loop operator afterwards.
287+
Every iteration produces tensors, all of them are gathered
288+
into lists, all these lists are concatenated into tensors.
289+
290+
:param n_iter: number of iterations, it can be fixed or
291+
variable, in that case it should be a tensor with no dimension
292+
:param body_fn: function body, takes only tensors and returns
293+
only tensors, the first argument is the iteration number
294+
in a tensor with no dimension, all the others
295+
are not changed during the loop
296+
:param args: the available tensors at every loop
297+
:param reduction_dim: the loop aggregates the results into lists,
298+
one for each output, each of them is concatenated into one
299+
tensor along one dimension, by default, it is the first
300+
dimension, but it can be defined otherwise
242301
243302
.. runpython::
244303
:showcode:
@@ -271,6 +330,52 @@ def body(i, x):
271330
use_control_flow_dispatcher=True,
272331
).model_proto
273332
333+
sess = onnxruntime.InferenceSession(
334+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
335+
)
336+
got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
337+
print("got:", got)
338+
339+
340+
# The loop is exported as a custom op.
341+
ep = torch.export.export(
342+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
343+
)
344+
print(ep)
345+
346+
Another example with two outputs:
347+
348+
.. runpython::
349+
:showcode:
350+
351+
import torch
352+
import onnxruntime
353+
from onnx_diagnostic.export.api import to_onnx
354+
from onnx_diagnostic.export.control_flow import loop_for
355+
356+
357+
class Model(torch.nn.Module):
358+
def forward(self, n_iter, x):
359+
def body(i, x):
360+
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
361+
362+
two = loop_for(n_iter, body, (x,))
363+
return two[0] + two[1]
364+
365+
366+
model = Model()
367+
n_iter = torch.tensor(4, dtype=torch.int64)
368+
x = torch.arange(10, dtype=torch.float32)
369+
expected = model(n_iter, x)
370+
print("expected:", expected)
371+
372+
onx = to_onnx(
373+
model,
374+
(n_iter, x),
375+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
376+
exporter="custom",
377+
use_control_flow_dispatcher=True,
378+
).model_proto
274379
275380
sess = onnxruntime.InferenceSession(
276381
onx.SerializeToString(), providers=["CPUExecutionProvider"]
@@ -285,6 +390,52 @@ def body(i, x):
285390
)
286391
print(ep)
287392
393+
A last example with ``reduction_dim``:
394+
395+
.. runpython::
396+
:showcode:
397+
398+
import torch
399+
import onnxruntime
400+
from onnx_diagnostic.export.api import to_onnx
401+
from onnx_diagnostic.export.control_flow import loop_for
402+
403+
404+
class Model(torch.nn.Module):
405+
def forward(self, n_iter, x):
406+
def body(i, x):
407+
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
408+
409+
two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
410+
return two[0] + two[1].T
411+
412+
413+
model = Model()
414+
n_iter = torch.tensor(4, dtype=torch.int64)
415+
x = torch.arange(10, dtype=torch.float32)
416+
expected = model(n_iter, x)
417+
print("expected:", expected)
418+
419+
onx = to_onnx(
420+
model,
421+
(n_iter, x),
422+
dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
423+
exporter="custom",
424+
use_control_flow_dispatcher=True,
425+
).model_proto
426+
427+
sess = onnxruntime.InferenceSession(
428+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
429+
)
430+
got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
431+
print("got:", got)
432+
433+
434+
# The loop is exported as a custom op.
435+
ep = torch.export.export(
436+
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
437+
)
438+
print(ep)
288439
"""
289440
assert args, "The function should have at least one arg."
290441
assert (

0 commit comments

Comments
 (0)