@@ -55,7 +55,7 @@ def is_exporting() -> bool:
5555 return _TEST_EXPORT or torch .compiler .is_exporting () or torch .compiler .is_compiling ()
5656
5757
58- def _loop_for_fn (n_iter , body_fn , reduction_dim , args ):
58+ def _loop_for_onnx_fn (n_iter , body_fn , reduction_dim , args ):
5959 """
6060 Python implementation of the loop.
6161
@@ -103,7 +103,7 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
103103 return tuple (final ) if len (final ) > 1 else final [0 ]
104104
105105
106- def make_custom_loop_for (
106+ def make_custom_loop_for_onnx (
107107 n_iter : torch .Tensor ,
108108 body_fn : Callable ,
109109 reduction_dim : Optional [Sequence [int ]],
@@ -139,7 +139,7 @@ def make_custom_loop_for(
139139 .replace ("<lambda>" , "l" )
140140 .replace ("." , "_" )
141141 )
142- name = f"loop_for_ { full_name } _{ srank } _{ sred } "
142+ name = f"loop_for_onnx_ { full_name } _{ srank } _{ sred } "
143143 if name in _REGISTERED_SCHEMA :
144144 return name , _REGISTERED_SCHEMA [name ][0 ]
145145 sig = inspect .signature (body_fn )
@@ -197,10 +197,10 @@ def convert_custom_loop_into_onnx(
197197 * args : str ,
198198 body_callable : Callable [..., onnx .ModelProto ],
199199 reduction_dim : Optional [Sequence [int ]] = None ,
200- name : str = "loop_for " ,
200+ name : str = "loop_for_onnx " ,
201201) -> Union [str , List [str ]]:
202202 """
203- Converts a custom op ``higher_ops::loop_for ...`` into e sequence of node.
203+ Converts a custom op ``higher_ops::loop_for_onnx ...`` into e sequence of node.
204204
205205 :param g: GreaphBuilder
206206 :param sts: if not defined, torch does not know the output shapes
@@ -267,7 +267,7 @@ def convert_custom_loop_into_onnx(
267267
268268 sequences = [g .op .SequenceEmpty () for _ in outputs ]
269269
270- outloop = [g .unique_name (f"loop_for { i } " ) for i in range (len (sequences ))]
270+ outloop = [g .unique_name (f"loop_for_onnx { i } " ) for i in range (len (sequences ))]
271271
272272 for i , s in enumerate (sequences ):
273273 g .set_sequence (s , graph .output [i ].type .tensor_type .elem_type )
@@ -321,7 +321,7 @@ def convert_into_onnx(
321321 return container .model_proto
322322
323323
324- def loop_for (
324+ def loop_for_onnx (
325325 n_iter : Union [torch .SymInt , torch .Tensor ],
326326 body_fn : Callable [..., Tuple [torch .Tensor ]],
327327 args : Sequence [torch .Tensor ],
@@ -352,15 +352,15 @@ def loop_for(
352352 import torch
353353 import onnxruntime
354354 from onnx_diagnostic.export.api import to_onnx
355- from onnx_diagnostic.export.control_flow import loop_for
355+ from onnx_diagnostic.export.control_flow import loop_for_onnx
356356
357357
358358 class Model(torch.nn.Module):
359359 def forward(self, n_iter, x):
360360 def body(i, x):
361361 return x[: i.item() + 1].unsqueeze(1)
362362
363- return loop_for (n_iter, body, (x,))
363+ return loop_for_onnx (n_iter, body, (x,))
364364
365365
366366 model = Model()
@@ -398,15 +398,15 @@ def body(i, x):
398398 import torch
399399 import onnxruntime
400400 from onnx_diagnostic.export.api import to_onnx
401- from onnx_diagnostic.export.control_flow import loop_for
401+ from onnx_diagnostic.export.control_flow import loop_for_onnx
402402
403403
404404 class Model(torch.nn.Module):
405405 def forward(self, n_iter, x):
406406 def body(i, x):
407407 return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
408408
409- two = loop_for (n_iter, body, (x,))
409+ two = loop_for_onnx (n_iter, body, (x,))
410410 return two[0] + two[1]
411411
412412
@@ -445,15 +445,15 @@ def body(i, x):
445445 import torch
446446 import onnxruntime
447447 from onnx_diagnostic.export.api import to_onnx
448- from onnx_diagnostic.export.control_flow import loop_for
448+ from onnx_diagnostic.export.control_flow import loop_for_onnx
449449
450450
451451 class Model(torch.nn.Module):
452452 def forward(self, n_iter, x):
453453 def body(i, x):
454454 return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
455455
456- two = loop_for (n_iter, body, (x,), reduction_dim=[0, 1])
456+ two = loop_for_onnx (n_iter, body, (x,), reduction_dim=[0, 1])
457457 return two[0] + two[1].T
458458
459459
@@ -501,7 +501,7 @@ def body(i, x):
501501 body_mutated_inputs ,
502502 body_outputs ,
503503 ) = check_input_alias_and_mutation_return_outputs (body_gm )
504- name , _custom_ops = make_custom_loop_for (
504+ name , _custom_ops = make_custom_loop_for_onnx (
505505 n_iter ,
506506 body_fn ,
507507 reduction_dim ,
@@ -513,4 +513,4 @@ def body(i, x):
513513 fct = getattr (torch .ops .onnx_higher_ops , name )
514514 return fct (n_iter , * args )
515515
516- return _loop_for_fn (n_iter , body_fn , reduction_dim , args )
516+ return _loop_for_onnx_fn (n_iter , body_fn , reduction_dim , args )
0 commit comments