Skip to content

Commit f9d798e

Browse files
authored
rename loop_for into loop_for_onnx (#327)
* rename loop_for into loop_for_onnx * renaming * fix loops
1 parent 7e7c2e8 commit f9d798e

File tree

9 files changed

+60
-43
lines changed

9 files changed

+60
-43
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Change Logs
2020

2121
* :pr:`303`: fix inputs for summarization, feature extraction tasks
2222
* :pr:`302`: adds helpers to analyse onnxruntime profiling
23-
* :pr:`297`: experiment around a higher ops ``loop_for``
23+
* :pr:`297`: experiment around a higher ops ``loop_for_onnx``
2424
* :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
2525

2626
0.8.1

_doc/api/export/control_flow.rst

Lines changed: 0 additions & 6 deletions
This file was deleted.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.export.control_flow_onnx
3+
========================================
4+
5+
.. automodule:: onnx_diagnostic.export.control_flow_onnx
6+
:members:

_doc/api/export/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ onnx_diagnostic.export
66
:caption: modules
77

88
api
9-
control_flow
9+
control_flow_onnx
1010
dynamic_shapes
1111
onnx_plug
1212
shape_helper

_unittests/ut_export/test_control_flow.py renamed to _unittests/ut_export/test_control_flow_onnx.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
from onnxscript import script, FLOAT, INT64
55
from onnxscript import opset18 as op
66
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test
7-
from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
7+
from onnx_diagnostic.export.control_flow_onnx import (
8+
enable_code_export_control_flow,
9+
loop_for_onnx,
10+
)
811
from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
912
from onnx_diagnostic.export.api import to_onnx
1013

1114

12-
class TestControlFlow(ExtTestCase):
15+
class TestControlFlowOnnx(ExtTestCase):
1316
@never_test()
1417
def test_loop_one_research(self):
1518
class Model(torch.nn.Module):
@@ -54,7 +57,7 @@ def forward(self, n_iter, x):
5457
def body(i, x):
5558
return x[: i.item() + 1].unsqueeze(1)
5659

57-
return loop_for(n_iter, body, (x,))
60+
return loop_for_onnx(n_iter, body, (x,))
5861

5962
model = Model()
6063
n_iter = torch.tensor(4, dtype=torch.int64)
@@ -67,7 +70,7 @@ def body(i, x):
6770
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
6871
)
6972
self.assertIn(
70-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
73+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_L_Model_forward_L_body_",
7174
str(ep),
7275
)
7376

@@ -88,7 +91,7 @@ def forward(self, n_iter, x):
8891
def body(i, x):
8992
return x[: i.item() + 1].unsqueeze(1)
9093

91-
return loop_for(n_iter, body, (x,))
94+
return loop_for_onnx(n_iter, body, (x,))
9295

9396
model = Model()
9497
n_iter = torch.tensor(4, dtype=torch.int64)
@@ -101,7 +104,7 @@ def body(i, x):
101104
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
102105
)
103106
self.assertIn(
104-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
107+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
105108
str(ep),
106109
)
107110

@@ -125,7 +128,7 @@ def forward(self, n_iter, x):
125128
def body(i, x):
126129
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
127130

128-
res = loop_for(n_iter, body, (x,))
131+
res = loop_for_onnx(n_iter, body, (x,))
129132
return res[0] + res[1]
130133

131134
model = Model()
@@ -139,7 +142,7 @@ def body(i, x):
139142
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
140143
)
141144
self.assertIn(
142-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
145+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_L_Model_forward_L_body_",
143146
str(ep),
144147
)
145148

@@ -160,7 +163,7 @@ def forward(self, n_iter, x):
160163
def body(i, x):
161164
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
162165

163-
res = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
166+
res = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
164167
return res[0] + res[1].T
165168

166169
model = Model()
@@ -174,7 +177,7 @@ def body(i, x):
174177
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
175178
)
176179
self.assertIn(
177-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
180+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
178181
str(ep),
179182
)
180183

onnx_diagnostic/export/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def to_onnx(
4343
:param save_ep: saves the exported program
4444
:param optimize: optimizes the model
4545
:param use_control_flow_dispatcher: use the dispatcher created to supported
46-
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
46+
custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
4747
:param onnx_plugs: the code was modified to replace some parts with onnx translation
4848
:param inline: inline local functions
4949
:return: the output of the selected exporter, usually a structure including
@@ -62,7 +62,7 @@ def to_onnx(
6262
)
6363
6464
Some examples using control flows are available in
65-
:func:`onnx_diagnostic.export.control_flow.loop_for` or
65+
:func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
6666
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
6767
"""
6868
if exporter_kwargs and "inline" in exporter_kwargs:
@@ -86,7 +86,7 @@ def to_onnx(
8686
from experimental_experiment.torch_interpreter import Dispatcher
8787

8888
if use_control_flow_dispatcher:
89-
from .control_flow import create_global_dispatcher
89+
from .control_flow_onnx import create_global_dispatcher
9090

9191
control_flow_dispatcher = create_global_dispatcher()
9292
else:

onnx_diagnostic/export/control_flow.py renamed to onnx_diagnostic/export/control_flow_onnx.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def is_exporting() -> bool:
5555
return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
5656

5757

58-
def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
58+
def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
5959
"""
6060
Python implementation of the loop.
6161
@@ -103,7 +103,7 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
103103
return tuple(final) if len(final) > 1 else final[0]
104104

105105

106-
def make_custom_loop_for(
106+
def make_custom_loop_for_onnx(
107107
n_iter: torch.Tensor,
108108
body_fn: Callable,
109109
reduction_dim: Optional[Sequence[int]],
@@ -139,7 +139,7 @@ def make_custom_loop_for(
139139
.replace("<lambda>", "l")
140140
.replace(".", "_")
141141
)
142-
name = f"loop_for_{full_name}_{srank}_{sred}"
142+
name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
143143
if name in _REGISTERED_SCHEMA:
144144
return name, _REGISTERED_SCHEMA[name][0]
145145
sig = inspect.signature(body_fn)
@@ -197,10 +197,10 @@ def convert_custom_loop_into_onnx(
197197
*args: str,
198198
body_callable: Callable[..., onnx.ModelProto],
199199
reduction_dim: Optional[Sequence[int]] = None,
200-
name: str = "loop_for",
200+
name: str = "loop_for_onnx",
201201
) -> Union[str, List[str]]:
202202
"""
203-
Converts a custom op ``higher_ops::loop_for...`` into e sequence of node.
203+
Converts a custom op ``higher_ops::loop_for_onnx...`` into e sequence of node.
204204
205205
:param g: GreaphBuilder
206206
:param sts: if not defined, torch does not know the output shapes
@@ -265,9 +265,19 @@ def convert_custom_loop_into_onnx(
265265
nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
266266
)
267267

268-
sequences = [g.op.SequenceEmpty() for _ in outputs]
268+
itypes = [
269+
graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type
270+
for i in range(1, len(graph.output))
271+
]
272+
assert len(outputs) == len(
273+
itypes
274+
), f"Length mismatch between outputs={outputs} and graph.output={graph.output}"
275+
assert (
276+
0 not in itypes
277+
), f"Undefined types are not allowed in itype={itypes}, graph.output={graph.output}"
278+
sequences = [g.op.SequenceEmpty(dtype=itype) for itype in itypes]
269279

270-
outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
280+
outloop = [g.unique_name(f"loop_for_onnx{i}") for i in range(len(sequences))]
271281

272282
for i, s in enumerate(sequences):
273283
g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
@@ -285,8 +295,10 @@ def convert_custom_loop_into_onnx(
285295
]
286296
if not sts:
287297
for i, o in enumerate(outputs):
288-
g.set_type(o, graph.output[i].type.tensor_type.elem_type)
289-
g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
298+
g.set_type(o, graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type)
299+
g.set_rank(
300+
o, len(graph.output[i].type.sequence_type.elem_type.tensor_type.shape.dims)
301+
)
290302
return outputs if len(outputs) > 1 else outputs[0]
291303

292304

@@ -321,7 +333,7 @@ def convert_into_onnx(
321333
return container.model_proto
322334

323335

324-
def loop_for(
336+
def loop_for_onnx(
325337
n_iter: Union[torch.SymInt, torch.Tensor],
326338
body_fn: Callable[..., Tuple[torch.Tensor]],
327339
args: Sequence[torch.Tensor],
@@ -352,15 +364,15 @@ def loop_for(
352364
import torch
353365
import onnxruntime
354366
from onnx_diagnostic.export.api import to_onnx
355-
from onnx_diagnostic.export.control_flow import loop_for
367+
from onnx_diagnostic.export.control_flow import loop_for_onnx
356368
357369
358370
class Model(torch.nn.Module):
359371
def forward(self, n_iter, x):
360372
def body(i, x):
361373
return x[: i.item() + 1].unsqueeze(1)
362374
363-
return loop_for(n_iter, body, (x,))
375+
return loop_for_onnx(n_iter, body, (x,))
364376
365377
366378
model = Model()
@@ -398,15 +410,15 @@ def body(i, x):
398410
import torch
399411
import onnxruntime
400412
from onnx_diagnostic.export.api import to_onnx
401-
from onnx_diagnostic.export.control_flow import loop_for
413+
from onnx_diagnostic.export.control_flow import loop_for_onnx
402414
403415
404416
class Model(torch.nn.Module):
405417
def forward(self, n_iter, x):
406418
def body(i, x):
407419
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
408420
409-
two = loop_for(n_iter, body, (x,))
421+
two = loop_for_onnx(n_iter, body, (x,))
410422
return two[0] + two[1]
411423
412424
@@ -445,15 +457,15 @@ def body(i, x):
445457
import torch
446458
import onnxruntime
447459
from onnx_diagnostic.export.api import to_onnx
448-
from onnx_diagnostic.export.control_flow import loop_for
460+
from onnx_diagnostic.export.control_flow import loop_for_onnx
449461
450462
451463
class Model(torch.nn.Module):
452464
def forward(self, n_iter, x):
453465
def body(i, x):
454466
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
455467
456-
two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
468+
two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
457469
return two[0] + two[1].T
458470
459471
@@ -501,7 +513,7 @@ def body(i, x):
501513
body_mutated_inputs,
502514
body_outputs,
503515
) = check_input_alias_and_mutation_return_outputs(body_gm)
504-
name, _custom_ops = make_custom_loop_for(
516+
name, _custom_ops = make_custom_loop_for_onnx(
505517
n_iter,
506518
body_fn,
507519
reduction_dim,
@@ -513,4 +525,4 @@ def body(i, x):
513525
fct = getattr(torch.ops.onnx_higher_ops, name)
514526
return fct(n_iter, *args)
515527

516-
return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
528+
return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)

onnx_diagnostic/export/control_flow_research.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
1616
from torch.utils._python_dispatch import _get_current_dispatch_mode
17-
from .control_flow import _loop_for_fn
17+
from .control_flow_onnx import _loop_for_onnx_fn
1818

1919

2020
class SimpleLoopForOp(HigherOrderOperator):
@@ -66,7 +66,7 @@ def simple_loop_for(
6666
return simple_loop_for_op(n_iter, body_fn, (n_iter, *operands))
6767

6868
if isinstance(n_iter, (bool, int, float)):
69-
return _loop_for_fn(body_fn, n_iter, None, *operands)
69+
return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)
7070

7171
def _validate_input(n_iter, body_fn, operands):
7272
assert isinstance(
@@ -127,7 +127,7 @@ def loop_for_op_dense(n_iter, body_fn, operands):
127127
), f"Dense implementation operands must be a list of tensors and ints {operands}"
128128
mode = _get_current_dispatch_mode()
129129
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
130-
return _loop_for_fn(body_fn, n_iter, None, *operands)
130+
return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)
131131

132132

133133
@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def append_output_sequence(
159159
"""
160160
if not tensors:
161161
# empty list
162-
self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
162+
self.nodes.append(
163+
oh.make_node("SequenceEmpty", [], [name], dtype=TensorProto.FLOAT)
164+
)
163165
tensor_type_proto = oh.make_tensor_type_proto(
164166
elem_type=TensorProto.FLOAT, shape=None
165167
)

0 commit comments

Comments
 (0)