66import torch
77from torch ._higher_order_ops .utils import materialize_as_graph
88from torch ._higher_order_ops .utils import check_input_alias_and_mutation_return_outputs
9- from ..helpers .onnx_helper import pretty_onnx
109from .api import to_onnx
1110
1211_TEST_EXPORT = False
@@ -47,15 +46,25 @@ def enable_code_export_control_flow():
4746 _TEST_EXPORT = old
4847
4948
50- def is_exporting ():
49+ def is_exporting () -> bool :
5150 """
5251 Returns :func:`torch.compiler.is_exporting` or
53- :func:`torch.compiler.is_compiling`
52+ :func:`torch.compiler.is_compiling`.
53+ Setting ``_TEST_EXPORT`` to True forces this function to return True.
5454 """
5555 return _TEST_EXPORT or torch .compiler .is_exporting () or torch .compiler .is_compiling ()
5656
5757
5858def _loop_for_fn (n_iter , body_fn , reduction_dim , args ):
59+ """
60+ Python implementation of the loop.
61+
62+ :param n_iter: number of iterations
63+ :param body_fn: function implementing the body
64+ :param reduction_dim: dimension used to reduce the list produced by the loop
65+ :param args: arguments to the loop body
66+ :return: results
67+ """
5968 res = []
6069 for i in torch .arange (n_iter , dtype = n_iter .dtype ):
6170 r = body_fn (i , * args )
@@ -95,14 +104,30 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
95104
96105
97106def make_custom_loop_for (
98- n_iter ,
99- body_fn ,
100- reduction_dim ,
101- args ,
102- body_gm = None ,
103- body_mutated_inputs = None ,
104- body_outputs = None ,
105- ):
107+ n_iter : torch .Tensor ,
108+ body_fn : Callable ,
109+ reduction_dim : Optional [List [int ]],
110+ args : List [torch .Tensor ],
111+ body_gm : Optional [torch .fx .GraphModule ] = None ,
112+ body_mutated_inputs : Optional [List [Any ]] = None ,
113+ body_outputs : Optional [List [Any ]] = None ,
114+ ) -> Tuple [str , torch .library .CustomOpDef ]:
115+ """
116+ Defines a custom operator for a loop in order to avoid
117+ :func:`torch.export.export` digging into it.
118+ It registers the custom op and a custom conversion
119+ to ONNX.
120+
121+ :param n_iter: number of iterations defined by a tensor of no dimension
122+ :param body_fn: the loop body defined as a function
123+ :param reduction_dim: dimension used to concatenate the results
124+ :param args: list of tensors, input to the body
125+ :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
126+ :param body_mutated_inputs: inputs to *body_gm*
127+ :param body_outputs: outputs to *body_gm*
128+ :return: a name and the custom op definition, the name
129+ is used to cache the custom op
130+ """
106131 global _DISPATCHER
107132 srank = "_" .join ("x" .join (map (str , s .shape )) for s in body_outputs )
108133 sred = "x" .join (map (str , reduction_dim )) if reduction_dim else ""
@@ -120,7 +145,7 @@ def make_custom_loop_for(
120145 custom_def ._abstract_fn = lambda * _args , _o = body_outputs : (
121146 tuple ([torch .empty_like (s ) for s in _o ]) if len (_o ) > 1 else torch .empty_like (_o [0 ])
122147 )
123- onx = convert_into_onnx (body_gm , args , body_mutated_inputs , body_outputs )
148+ onx = convert_into_onnx (body_gm , args )
124149 to_register = (
125150 custom_def ,
126151 onx ,
@@ -149,15 +174,26 @@ def convert_custom_loop_into_onnx(
149174 reduction_dim : Optional [Tuple [int , ...]] = None ,
150175 name : str = "loop_for" ,
151176) -> Union [str , Tuple [str , ...]]:
177+ """
178+ Converts a custom op ``higher_ops::loop_for...`` into a sequence of nodes.
179+
180+ :param g: GraphBuilder
181+ :param sts: if not defined, torch does not know the output shapes
182+ :param outputs: output names
183+ :param args: input argument known at export time
184+ :param body: GraphProto, the loop body
185+ :param reduction_dim: the dimension to follow when aggregating the
186+ list of tensors after the loop ran
187+ :param name: to give the onnx nodes a name
188+ :return: output names
189+ """
152190 graph = body .graph if isinstance (body , onnx .ModelProto ) else body
153191 assert isinstance (
154192 graph , onnx .GraphProto
155193 ), f"Unexpected type { type (body )} for body{ g .get_debug_msg ()} "
156- assert len (outputs ) == len (graph .output ), (
157- f"Length mismatch, expecting { len (outputs )} but got "
158- f"{ len (graph .output )} , \n --\n { pretty_onnx (body )} "
159- f"{ g .get_debug_msg ()} "
160- )
194+ assert len (outputs ) == 1 , f"Only one outputs is expected but outputs={ outputs !r} "
195+ if len (graph .output ) != 1 :
196+ outputs = [f"{ outputs [0 ]} #{ i } " for i in range (len (graph .output ))]
161197 input_names = [i .name for i in graph .input ]
162198 inputs = [
163199 * graph .input [:1 ],
@@ -218,11 +254,20 @@ def convert_custom_loop_into_onnx(
218254 for i , o in enumerate (outputs ):
219255 g .set_type (o , graph .output [i ].type .tensor_type .elem_type )
220256 g .set_rank (o , len (graph .output [i ].type .tensor_type .shape .dims ))
221- return tuple (outputs ) if len (outputs ) > 1 else outputs [0 ]
257+ return outputs if len (outputs ) > 1 else outputs [0 ]
258+
222259
260+ def convert_into_onnx (
261+ body_gm : torch .fx .GraphModule , args : List [torch .Tensor ]
262+ ) -> onnx .ModelProto :
263+ """
264+ Converts a torch.fx.GraphModule into ONNX.
265+ It returns a ModelProto.
223266
224- def convert_into_onnx (body_gm , args , body_mutated_inputs , body_outputs ):
225- """Converts into ONNX."""
267+ :param body_gm: a torch.fx.GraphModule
268+ :param args: arguments known at export time
269+ :return: a ModelProto
270+ """
226271 # This does not work with onnx-dynamo.
227272 # opset still needs to be defined
228273 container = to_onnx (body_gm , args , exporter = "custom" )
@@ -239,6 +284,20 @@ def loop_for(
239284 High operators used to easily export a loop in ONNX.
240285 Does not fully work with :func:`torch.export.export`,
241286 it does replace a custom op with a loop operator afterwards.
287+ Every iteration produces tensors, all of them are gathered
288+ into lists, all these lists are concatenated into tensors.
289+
290+ :param n_iter: number of iterations, it can be fixed or
291+ variable, in that case it should be a tensor with no dimension
292+ :param body_fn: function body, takes only tensors and returns
293+ only tensors, the first argument is the iteration number
294+ in a tensor with no dimension, all the others
295+ are not changed during the loop
296+ :param args: the available tensors at every loop
297+ :param reduction_dim: the loop aggregates the results into lists,
298+ one for each output, each of them is concatenated into one
299+ tensor along one dimension, by default, it is the first
300+ dimension, but it can be defined otherwise
242301
243302 .. runpython::
244303 :showcode:
@@ -271,6 +330,52 @@ def body(i, x):
271330 use_control_flow_dispatcher=True,
272331 ).model_proto
273332
333+ sess = onnxruntime.InferenceSession(
334+ onx.SerializeToString(), providers=["CPUExecutionProvider"]
335+ )
336+ got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
337+ print("got:", got)
338+
339+
340+ # The loop is exported as a custom op.
341+ ep = torch.export.export(
342+ model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
343+ )
344+ print(ep)
345+
346+ Another example with two outputs:
347+
348+ .. runpython::
349+ :showcode:
350+
351+ import torch
352+ import onnxruntime
353+ from onnx_diagnostic.export.api import to_onnx
354+ from onnx_diagnostic.export.control_flow import loop_for
355+
356+
357+ class Model(torch.nn.Module):
358+ def forward(self, n_iter, x):
359+ def body(i, x):
360+ return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
361+
362+ two = loop_for(n_iter, body, (x,))
363+ return two[0] + two[1]
364+
365+
366+ model = Model()
367+ n_iter = torch.tensor(4, dtype=torch.int64)
368+ x = torch.arange(10, dtype=torch.float32)
369+ expected = model(n_iter, x)
370+ print("expected:", expected)
371+
372+ onx = to_onnx(
373+ model,
374+ (n_iter, x),
375+ dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
376+ exporter="custom",
377+ use_control_flow_dispatcher=True,
378+ ).model_proto
274379
275380 sess = onnxruntime.InferenceSession(
276381 onx.SerializeToString(), providers=["CPUExecutionProvider"]
@@ -285,6 +390,52 @@ def body(i, x):
285390 )
286391 print(ep)
287392
393+ A last example with ``reduction_dim``:
394+
395+ .. runpython::
396+ :showcode:
397+
398+ import torch
399+ import onnxruntime
400+ from onnx_diagnostic.export.api import to_onnx
401+ from onnx_diagnostic.export.control_flow import loop_for
402+
403+
404+ class Model(torch.nn.Module):
405+ def forward(self, n_iter, x):
406+ def body(i, x):
407+ return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
408+
409+ two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
410+ return two[0] + two[1].T
411+
412+
413+ model = Model()
414+ n_iter = torch.tensor(4, dtype=torch.int64)
415+ x = torch.arange(10, dtype=torch.float32)
416+ expected = model(n_iter, x)
417+ print("expected:", expected)
418+
419+ onx = to_onnx(
420+ model,
421+ (n_iter, x),
422+ dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
423+ exporter="custom",
424+ use_control_flow_dispatcher=True,
425+ ).model_proto
426+
427+ sess = onnxruntime.InferenceSession(
428+ onx.SerializeToString(), providers=["CPUExecutionProvider"]
429+ )
430+ got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
431+ print("got:", got)
432+
433+
434+ # The loop is exported as a custom op.
435+ ep = torch.export.export(
436+ model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
437+ )
438+ print(ep)
288439 """
289440 assert args , "The function should have at least one arg."
290441 assert (
0 commit comments