@@ -55,7 +55,7 @@ def is_exporting() -> bool:
     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()


-def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
+def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
     """
     Python implementation of the loop.

@@ -103,7 +103,7 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
     return tuple(final) if len(final) > 1 else final[0]


-def make_custom_loop_for(
+def make_custom_loop_for_onnx(
     n_iter: torch.Tensor,
     body_fn: Callable,
     reduction_dim: Optional[Sequence[int]],
@@ -139,7 +139,7 @@ def make_custom_loop_for(
         .replace("<lambda>", "l")
         .replace(".", "_")
     )
-    name = f"loop_for_{full_name}_{srank}_{sred}"
+    name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
     if name in _REGISTERED_SCHEMA:
         return name, _REGISTERED_SCHEMA[name][0]
     sig = inspect.signature(body_fn)
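For reference, a rough illustration of how the renamed cache key could be assembled. Only the `.replace(...)` tail of the prefix construction is visible in this hunk, so `full_name`, `srank` and `sred` below are assumed stand-ins, not the library's exact encoding.

# Hypothetical illustration only: shape of the registered op name.
def body(i, x):
    return x[: i + 1]

full_name = (
    f"{body.__module__}.{body.__qualname__}".replace("<lambda>", "l").replace(".", "_")
)
srank, sred = "1", "none"  # assumed encodings of input ranks and reduction dims
name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
print(name)  # e.g. loop_for_onnx___main___body_1_none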
@@ -197,10 +197,10 @@ def convert_custom_loop_into_onnx(
     *args: str,
     body_callable: Callable[..., onnx.ModelProto],
     reduction_dim: Optional[Sequence[int]] = None,
-    name: str = "loop_for",
+    name: str = "loop_for_onnx",
 ) -> Union[str, List[str]]:
     """
-    Converts a custom op ``higher_ops::loop_for...`` into a sequence of nodes.
+    Converts a custom op ``higher_ops::loop_for_onnx...`` into a sequence of nodes.

     :param g: GraphBuilder
     :param sts: if not defined, torch does not know the output shapes
@@ -265,9 +265,19 @@ def convert_custom_loop_into_onnx(
         nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
     )

-    sequences = [g.op.SequenceEmpty() for _ in outputs]
+    itypes = [
+        graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type
+        for i in range(1, len(graph.output))
+    ]
+    assert len(outputs) == len(
+        itypes
+    ), f"Length mismatch between outputs={outputs} and graph.output={graph.output}"
+    assert (
+        0 not in itypes
+    ), f"Undefined types are not allowed in itype={itypes}, graph.output={graph.output}"
+    sequences = [g.op.SequenceEmpty(dtype=itype) for itype in itypes]

-    outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
+    outloop = [g.unique_name(f"loop_for_onnx{i}") for i in range(len(sequences))]

     for i, s in enumerate(sequences):
         g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
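The `dtype` addition matters because ONNX's SequenceEmpty produces a sequence of float tensors when its optional ``dtype`` attribute is omitted; the new code reads the element type from each sequence output of the body graph and passes it explicitly. A minimal standalone sketch with plain ``onnx.helper`` (not the GraphBuilder API used above):

# Minimal sketch: SequenceEmpty with an explicit element type.
# Without the dtype attribute, the op defaults to a sequence of float tensors.
from onnx import TensorProto, helper

node = helper.make_node("SequenceEmpty", inputs=[], outputs=["seq"], dtype=TensorProto.INT64)
print(node)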
@@ -285,8 +295,10 @@ def convert_custom_loop_into_onnx(
     ]
     if not sts:
         for i, o in enumerate(outputs):
-            g.set_type(o, graph.output[i].type.tensor_type.elem_type)
-            g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
+            g.set_type(o, graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type)
+            g.set_rank(
+                o, len(graph.output[i].type.sequence_type.elem_type.tensor_type.shape.dims)
+            )
     return outputs if len(outputs) > 1 else outputs[0]

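The fallback type and rank propagation changes for the same reason: the loop outputs are now sequences, so the element dtype and shape sit one level deeper, under ``type.sequence_type.elem_type.tensor_type``. A small standalone sketch of that protobuf path (names are illustrative):

# Sketch of the protobuf path the updated fallback reads for sequence-typed outputs.
from onnx import TensorProto, helper

vi = helper.make_tensor_sequence_value_info("out0", TensorProto.FLOAT, ["N", 1])
tt = vi.type.sequence_type.elem_type.tensor_type
print(tt.elem_type)       # 1 == TensorProto.FLOAT
print(len(tt.shape.dim))  # 2, the rank of the sequence elements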
@@ -321,7 +333,7 @@ def convert_into_onnx(
     return container.model_proto


-def loop_for(
+def loop_for_onnx(
     n_iter: Union[torch.SymInt, torch.Tensor],
     body_fn: Callable[..., Tuple[torch.Tensor]],
     args: Sequence[torch.Tensor],
@@ -352,15 +364,15 @@ def loop_for(
         import torch
         import onnxruntime
         from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
+        from onnx_diagnostic.export.control_flow import loop_for_onnx


         class Model(torch.nn.Module):
             def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1)

-                return loop_for(n_iter, body, (x,))
+                return loop_for_onnx(n_iter, body, (x,))


         model = Model()
@@ -398,15 +410,15 @@ def body(i, x):
         import torch
         import onnxruntime
         from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
+        from onnx_diagnostic.export.control_flow import loop_for_onnx


         class Model(torch.nn.Module):
             def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1

-                two = loop_for(n_iter, body, (x,))
+                two = loop_for_onnx(n_iter, body, (x,))
                 return two[0] + two[1]

@@ -445,15 +457,15 @@ def body(i, x):
         import torch
         import onnxruntime
         from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
+        from onnx_diagnostic.export.control_flow import loop_for_onnx


         class Model(torch.nn.Module):
             def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

-                two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
+                two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
                 return two[0] + two[1].T

@@ -501,7 +513,7 @@ def body(i, x):
             body_mutated_inputs,
             body_outputs,
         ) = check_input_alias_and_mutation_return_outputs(body_gm)
-        name, _custom_ops = make_custom_loop_for(
+        name, _custom_ops = make_custom_loop_for_onnx(
             n_iter,
             body_fn,
             reduction_dim,
@@ -513,4 +525,4 @@ def body(i, x):
         fct = getattr(torch.ops.onnx_higher_ops, name)
         return fct(n_iter, *args)

-    return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
+    return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)
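Outside export, the call falls back to the plain Python loop (`_loop_for_onnx_fn`, whose body is not shown in this diff). A rough eager-mode sketch of that behaviour, assuming each iteration's outputs are concatenated along the reduction dimension (0 by default), consistent with the docstring examples above:

# Rough eager-mode sketch only; not the actual _loop_for_onnx_fn implementation.
import torch

def loop_for_onnx_eager(n_iter, body_fn, args, reduction_dim=None):
    # Run the body once per iteration, then concatenate each output along its reduction dim.
    chunks = [body_fn(torch.tensor(i), *args) for i in range(int(n_iter))]
    chunks = [c if isinstance(c, tuple) else (c,) for c in chunks]
    dims = reduction_dim if reduction_dim is not None else [0] * len(chunks[0])
    final = [torch.cat([c[j] for c in chunks], dim=d) for j, d in enumerate(dims)]
    return tuple(final) if len(final) > 1 else final[0]

x = torch.arange(6, dtype=torch.float32)
res = loop_for_onnx_eager(torch.tensor(3), lambda i, t: t[: i.item() + 1].unsqueeze(1), (x,))
print(res.shape)  # torch.Size([6, 1]): 1 + 2 + 3 rows stacked along dim 0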