Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Change Logs

* :pr:`303`: fix inputs for summarization, feature extraction tasks
* :pr:`302`: adds helpers to analyse onnxruntime profiling
* :pr:`297`: experiment around a higher ops ``loop_for``
* :pr:`297`: experiment around a higher ops ``loop_for_onnx``
* :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models

0.8.1
Expand Down
6 changes: 0 additions & 6 deletions _doc/api/export/control_flow.rst

This file was deleted.

6 changes: 6 additions & 0 deletions _doc/api/export/control_flow_onnx.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

onnx_diagnostic.export.control_flow_onnx
========================================

.. automodule:: onnx_diagnostic.export.control_flow_onnx
:members:
2 changes: 1 addition & 1 deletion _doc/api/export/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ onnx_diagnostic.export
:caption: modules

api
control_flow
control_flow_onnx
dynamic_shapes
onnx_plug
shape_helper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from onnxscript import script, FLOAT, INT64
from onnxscript import opset18 as op
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test
from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
from onnx_diagnostic.export.control_flow_onnx import (
enable_code_export_control_flow,
loop_for_onnx,
)
from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
from onnx_diagnostic.export.api import to_onnx


class TestControlFlow(ExtTestCase):
class TestControlFlowOnnx(ExtTestCase):
@never_test()
def test_loop_one_research(self):
class Model(torch.nn.Module):
Expand Down Expand Up @@ -54,7 +57,7 @@ def forward(self, n_iter, x):
def body(i, x):
return x[: i.item() + 1].unsqueeze(1)

return loop_for(n_iter, body, (x,))
return loop_for_onnx(n_iter, body, (x,))

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
Expand All @@ -67,7 +70,7 @@ def body(i, x):
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_L_Model_forward_L_body_",
str(ep),
)

Expand All @@ -88,7 +91,7 @@ def forward(self, n_iter, x):
def body(i, x):
return x[: i.item() + 1].unsqueeze(1)

return loop_for(n_iter, body, (x,))
return loop_for_onnx(n_iter, body, (x,))

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
Expand All @@ -101,7 +104,7 @@ def body(i, x):
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
str(ep),
)

Expand All @@ -125,7 +128,7 @@ def forward(self, n_iter, x):
def body(i, x):
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1

res = loop_for(n_iter, body, (x,))
res = loop_for_onnx(n_iter, body, (x,))
return res[0] + res[1]

model = Model()
Expand All @@ -139,7 +142,7 @@ def body(i, x):
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_L_Model_forward_L_body_",
str(ep),
)

Expand All @@ -160,7 +163,7 @@ def forward(self, n_iter, x):
def body(i, x):
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

res = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
res = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
return res[0] + res[1].T

model = Model()
Expand All @@ -174,7 +177,7 @@ def body(i, x):
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
str(ep),
)

Expand Down
6 changes: 3 additions & 3 deletions onnx_diagnostic/export/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def to_onnx(
:param save_ep: saves the exported program
:param optimize: optimizes the model
:param use_control_flow_dispatcher: use the dispatcher created to supported
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
:param onnx_plugs: the code was modified to replace some parts with onnx translation
:param inline: inline local functions
:return: the output of the selected exporter, usually a structure including
Expand All @@ -62,7 +62,7 @@ def to_onnx(
)

Some examples using control flows are available in
:func:`onnx_diagnostic.export.control_flow.loop_for` or
:func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
"""
if exporter_kwargs and "inline" in exporter_kwargs:
Expand All @@ -86,7 +86,7 @@ def to_onnx(
from experimental_experiment.torch_interpreter import Dispatcher

if use_control_flow_dispatcher:
from .control_flow import create_global_dispatcher
from .control_flow_onnx import create_global_dispatcher

control_flow_dispatcher = create_global_dispatcher()
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def is_exporting() -> bool:
return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()


def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
"""
Python implementation of the loop.

Expand Down Expand Up @@ -103,7 +103,7 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
return tuple(final) if len(final) > 1 else final[0]


def make_custom_loop_for(
def make_custom_loop_for_onnx(
n_iter: torch.Tensor,
body_fn: Callable,
reduction_dim: Optional[Sequence[int]],
Expand Down Expand Up @@ -139,7 +139,7 @@ def make_custom_loop_for(
.replace("<lambda>", "l")
.replace(".", "_")
)
name = f"loop_for_{full_name}_{srank}_{sred}"
name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
if name in _REGISTERED_SCHEMA:
return name, _REGISTERED_SCHEMA[name][0]
sig = inspect.signature(body_fn)
Expand Down Expand Up @@ -197,10 +197,10 @@ def convert_custom_loop_into_onnx(
*args: str,
body_callable: Callable[..., onnx.ModelProto],
reduction_dim: Optional[Sequence[int]] = None,
name: str = "loop_for",
name: str = "loop_for_onnx",
) -> Union[str, List[str]]:
"""
Converts a custom op ``higher_ops::loop_for...`` into e sequence of node.
Converts a custom op ``higher_ops::loop_for_onnx...`` into e sequence of node.

:param g: GreaphBuilder
:param sts: if not defined, torch does not know the output shapes
Expand Down Expand Up @@ -265,9 +265,19 @@ def convert_custom_loop_into_onnx(
nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
)

sequences = [g.op.SequenceEmpty() for _ in outputs]
itypes = [
graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type
for i in range(1, len(graph.output))
]
assert len(outputs) == len(
itypes
), f"Length mismatch between outputs={outputs} and graph.output={graph.output}"
assert (
0 not in itypes
), f"Undefined types are not allowed in itype={itypes}, graph.output={graph.output}"
sequences = [g.op.SequenceEmpty(dtype=itype) for itype in itypes]

outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
outloop = [g.unique_name(f"loop_for_onnx{i}") for i in range(len(sequences))]

for i, s in enumerate(sequences):
g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
Expand All @@ -285,8 +295,10 @@ def convert_custom_loop_into_onnx(
]
if not sts:
for i, o in enumerate(outputs):
g.set_type(o, graph.output[i].type.tensor_type.elem_type)
g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
g.set_type(o, graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type)
g.set_rank(
o, len(graph.output[i].type.sequence_type.elem_type.tensor_type.shape.dims)
)
return outputs if len(outputs) > 1 else outputs[0]


Expand Down Expand Up @@ -321,7 +333,7 @@ def convert_into_onnx(
return container.model_proto


def loop_for(
def loop_for_onnx(
n_iter: Union[torch.SymInt, torch.Tensor],
body_fn: Callable[..., Tuple[torch.Tensor]],
args: Sequence[torch.Tensor],
Expand Down Expand Up @@ -352,15 +364,15 @@ def loop_for(
import torch
import onnxruntime
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.export.control_flow import loop_for
from onnx_diagnostic.export.control_flow import loop_for_onnx


class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i, x):
return x[: i.item() + 1].unsqueeze(1)

return loop_for(n_iter, body, (x,))
return loop_for_onnx(n_iter, body, (x,))


model = Model()
Expand Down Expand Up @@ -398,15 +410,15 @@ def body(i, x):
import torch
import onnxruntime
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.export.control_flow import loop_for
from onnx_diagnostic.export.control_flow import loop_for_onnx


class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i, x):
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1

two = loop_for(n_iter, body, (x,))
two = loop_for_onnx(n_iter, body, (x,))
return two[0] + two[1]


Expand Down Expand Up @@ -445,15 +457,15 @@ def body(i, x):
import torch
import onnxruntime
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.export.control_flow import loop_for
from onnx_diagnostic.export.control_flow import loop_for_onnx


class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i, x):
return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
return two[0] + two[1].T


Expand Down Expand Up @@ -501,7 +513,7 @@ def body(i, x):
body_mutated_inputs,
body_outputs,
) = check_input_alias_and_mutation_return_outputs(body_gm)
name, _custom_ops = make_custom_loop_for(
name, _custom_ops = make_custom_loop_for_onnx(
n_iter,
body_fn,
reduction_dim,
Expand All @@ -513,4 +525,4 @@ def body(i, x):
fct = getattr(torch.ops.onnx_higher_ops, name)
return fct(n_iter, *args)

return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)
6 changes: 3 additions & 3 deletions onnx_diagnostic/export/control_flow_research.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode
from .control_flow import _loop_for_fn
from .control_flow_onnx import _loop_for_onnx_fn


class SimpleLoopForOp(HigherOrderOperator):
Expand Down Expand Up @@ -66,7 +66,7 @@ def simple_loop_for(
return simple_loop_for_op(n_iter, body_fn, (n_iter, *operands))

if isinstance(n_iter, (bool, int, float)):
return _loop_for_fn(body_fn, n_iter, None, *operands)
return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)

def _validate_input(n_iter, body_fn, operands):
assert isinstance(
Expand Down Expand Up @@ -127,7 +127,7 @@ def loop_for_op_dense(n_iter, body_fn, operands):
), f"Dense implementation operands must be a list of tensors and ints {operands}"
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return _loop_for_fn(body_fn, n_iter, None, *operands)
return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)


@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
Expand Down
4 changes: 3 additions & 1 deletion onnx_diagnostic/helpers/mini_onnx_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def append_output_sequence(
"""
if not tensors:
# empty list
self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
self.nodes.append(
oh.make_node("SequenceEmpty", [], [name], dtype=TensorProto.FLOAT)
)
tensor_type_proto = oh.make_tensor_type_proto(
elem_type=TensorProto.FLOAT, shape=None
)
Expand Down
Loading