Skip to content

Commit 11139ec

Browse files
committed
rename loop_for into loop_for_onnx
1 parent 035782f commit 11139ec

File tree

4 files changed

+26
-26
lines changed

4 files changed

+26
-26
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Change Logs
1919

2020
* :pr:`303`: fix inputs for summarization, feature extraction tasks
2121
* :pr:`302`: adds helpers to analyse onnxruntime profiling
22-
* :pr:`297`: experiment around a higher ops ``loop_for``
22+
* :pr:`297`: experiment around a higher ops ``loop_for_onnx``
2323
* :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
2424

2525
0.8.1

_unittests/ut_export/test_control_flow.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from onnxscript import script, FLOAT, INT64
55
from onnxscript import opset18 as op
66
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test
7-
from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
7+
from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for_onnx
88
from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
99
from onnx_diagnostic.export.api import to_onnx
1010

@@ -54,7 +54,7 @@ def forward(self, n_iter, x):
5454
def body(i, x):
5555
return x[: i.item() + 1].unsqueeze(1)
5656

57-
return loop_for(n_iter, body, (x,))
57+
return loop_for_onnx(n_iter, body, (x,))
5858

5959
model = Model()
6060
n_iter = torch.tensor(4, dtype=torch.int64)
@@ -67,7 +67,7 @@ def body(i, x):
6767
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
6868
)
6969
self.assertIn(
70-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
70+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
7171
str(ep),
7272
)
7373

@@ -88,7 +88,7 @@ def forward(self, n_iter, x):
8888
def body(i, x):
8989
return x[: i.item() + 1].unsqueeze(1)
9090

91-
return loop_for(n_iter, body, (x,))
91+
return loop_for_onnx(n_iter, body, (x,))
9292

9393
model = Model()
9494
n_iter = torch.tensor(4, dtype=torch.int64)
@@ -125,7 +125,7 @@ def forward(self, n_iter, x):
125125
def body(i, x):
126126
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
127127

128-
res = loop_for(n_iter, body, (x,))
128+
res = loop_for_onnx(n_iter, body, (x,))
129129
return res[0] + res[1]
130130

131131
model = Model()
@@ -139,7 +139,7 @@ def body(i, x):
139139
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
140140
)
141141
self.assertIn(
142-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
142+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
143143
str(ep),
144144
)
145145

@@ -160,7 +160,7 @@ def forward(self, n_iter, x):
160160
def body(i, x):
161161
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
162162

163-
res = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
163+
res = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
164164
return res[0] + res[1].T
165165

166166
model = Model()
@@ -174,7 +174,7 @@ def body(i, x):
174174
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
175175
)
176176
self.assertIn(
177-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
177+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
178178
str(ep),
179179
)
180180

onnx_diagnostic/export/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def to_onnx(
4343
:param save_ep: saves the exported program
4444
:param optimize: optimizes the model
4545
:param use_control_flow_dispatcher: use the dispatcher created to supported
46-
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
46+
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for_onnx`)
4747
:param onnx_plugs: the code was modified to replace some parts with onnx translation
4848
:param inline: inline local functions
4949
:return: the output of the selected exporter, usually a structure including
@@ -62,7 +62,7 @@ def to_onnx(
6262
)
6363
6464
Some examples using control flows are available in
65-
:func:`onnx_diagnostic.export.control_flow.loop_for` or
65+
:func:`onnx_diagnostic.export.control_flow.loop_for_onnx` or
6666
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
6767
"""
6868
if exporter_kwargs and "inline" in exporter_kwargs:

onnx_diagnostic/export/control_flow.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def is_exporting() -> bool:
5555
return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
5656

5757

58-
def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
58+
def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
5959
"""
6060
Python implementation of the loop.
6161
@@ -103,7 +103,7 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
103103
return tuple(final) if len(final) > 1 else final[0]
104104

105105

106-
def make_custom_loop_for(
106+
def make_custom_loop_for_onnx(
107107
n_iter: torch.Tensor,
108108
body_fn: Callable,
109109
reduction_dim: Optional[Sequence[int]],
@@ -139,7 +139,7 @@ def make_custom_loop_for(
139139
.replace("<lambda>", "l")
140140
.replace(".", "_")
141141
)
142-
name = f"loop_for_{full_name}_{srank}_{sred}"
142+
name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
143143
if name in _REGISTERED_SCHEMA:
144144
return name, _REGISTERED_SCHEMA[name][0]
145145
sig = inspect.signature(body_fn)
@@ -197,10 +197,10 @@ def convert_custom_loop_into_onnx(
197197
*args: str,
198198
body_callable: Callable[..., onnx.ModelProto],
199199
reduction_dim: Optional[Sequence[int]] = None,
200-
name: str = "loop_for",
200+
name: str = "loop_for_onnx",
201201
) -> Union[str, List[str]]:
202202
"""
203-
Converts a custom op ``higher_ops::loop_for...`` into e sequence of node.
203+
Converts a custom op ``higher_ops::loop_for_onnx...`` into e sequence of node.
204204
205205
:param g: GreaphBuilder
206206
:param sts: if not defined, torch does not know the output shapes
@@ -267,7 +267,7 @@ def convert_custom_loop_into_onnx(
267267

268268
sequences = [g.op.SequenceEmpty() for _ in outputs]
269269

270-
outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
270+
outloop = [g.unique_name(f"loop_for_onnx{i}") for i in range(len(sequences))]
271271

272272
for i, s in enumerate(sequences):
273273
g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
@@ -321,7 +321,7 @@ def convert_into_onnx(
321321
return container.model_proto
322322

323323

324-
def loop_for(
324+
def loop_for_onnx(
325325
n_iter: Union[torch.SymInt, torch.Tensor],
326326
body_fn: Callable[..., Tuple[torch.Tensor]],
327327
args: Sequence[torch.Tensor],
@@ -352,15 +352,15 @@ def loop_for(
352352
import torch
353353
import onnxruntime
354354
from onnx_diagnostic.export.api import to_onnx
355-
from onnx_diagnostic.export.control_flow import loop_for
355+
from onnx_diagnostic.export.control_flow import loop_for_onnx
356356
357357
358358
class Model(torch.nn.Module):
359359
def forward(self, n_iter, x):
360360
def body(i, x):
361361
return x[: i.item() + 1].unsqueeze(1)
362362
363-
return loop_for(n_iter, body, (x,))
363+
return loop_for_onnx(n_iter, body, (x,))
364364
365365
366366
model = Model()
@@ -398,15 +398,15 @@ def body(i, x):
398398
import torch
399399
import onnxruntime
400400
from onnx_diagnostic.export.api import to_onnx
401-
from onnx_diagnostic.export.control_flow import loop_for
401+
from onnx_diagnostic.export.control_flow import loop_for_onnx
402402
403403
404404
class Model(torch.nn.Module):
405405
def forward(self, n_iter, x):
406406
def body(i, x):
407407
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
408408
409-
two = loop_for(n_iter, body, (x,))
409+
two = loop_for_onnx(n_iter, body, (x,))
410410
return two[0] + two[1]
411411
412412
@@ -445,15 +445,15 @@ def body(i, x):
445445
import torch
446446
import onnxruntime
447447
from onnx_diagnostic.export.api import to_onnx
448-
from onnx_diagnostic.export.control_flow import loop_for
448+
from onnx_diagnostic.export.control_flow import loop_for_onnx
449449
450450
451451
class Model(torch.nn.Module):
452452
def forward(self, n_iter, x):
453453
def body(i, x):
454454
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
455455
456-
two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
456+
two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
457457
return two[0] + two[1].T
458458
459459
@@ -501,7 +501,7 @@ def body(i, x):
501501
body_mutated_inputs,
502502
body_outputs,
503503
) = check_input_alias_and_mutation_return_outputs(body_gm)
504-
name, _custom_ops = make_custom_loop_for(
504+
name, _custom_ops = make_custom_loop_for_onnx(
505505
n_iter,
506506
body_fn,
507507
reduction_dim,
@@ -513,4 +513,4 @@ def body(i, x):
513513
fct = getattr(torch.ops.onnx_higher_ops, name)
514514
return fct(n_iter, *args)
515515

516-
return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
516+
return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)

0 commit comments

Comments
 (0)