diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 7dbb847e..85f20dbf 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -20,7 +20,7 @@ Change Logs
 * :pr:`303`: fix inputs for summarization, feature extraction tasks
 * :pr:`302`: adds helpers to analyse onnxruntime profiling
-* :pr:`297`: experiment around a higher ops ``loop_for``
+* :pr:`297`: experiment around a higher ops ``loop_for_onnx``
 * :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models

 0.8.1
diff --git a/_doc/api/export/control_flow.rst b/_doc/api/export/control_flow.rst
deleted file mode 100644
index 4b0aba66..00000000
--- a/_doc/api/export/control_flow.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-
-onnx_diagnostic.export.control_flow
-===================================
-
-.. automodule:: onnx_diagnostic.export.control_flow
-    :members:
diff --git a/_doc/api/export/control_flow_onnx.rst b/_doc/api/export/control_flow_onnx.rst
new file mode 100644
index 00000000..8a6c41bd
--- /dev/null
+++ b/_doc/api/export/control_flow_onnx.rst
@@ -0,0 +1,6 @@
+
+onnx_diagnostic.export.control_flow_onnx
+========================================
+
+.. automodule:: onnx_diagnostic.export.control_flow_onnx
+    :members:
diff --git a/_doc/api/export/index.rst b/_doc/api/export/index.rst
index ce546228..8c806fe6 100644
--- a/_doc/api/export/index.rst
+++ b/_doc/api/export/index.rst
@@ -6,7 +6,7 @@ onnx_diagnostic.export
     :caption: modules

     api
-    control_flow
+    control_flow_onnx
     dynamic_shapes
     onnx_plug
     shape_helper
diff --git a/_unittests/ut_export/test_control_flow.py b/_unittests/ut_export/test_control_flow_onnx.py
similarity index 87%
rename from _unittests/ut_export/test_control_flow.py
rename to _unittests/ut_export/test_control_flow_onnx.py
index 6e10c708..4791f22b 100644
--- a/_unittests/ut_export/test_control_flow.py
+++ b/_unittests/ut_export/test_control_flow_onnx.py
@@ -4,12 +4,15 @@
 from onnxscript import script, FLOAT, INT64
 from onnxscript import opset18 as op
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test
-from onnx_diagnostic.export.control_flow import enable_code_export_control_flow, loop_for
+from onnx_diagnostic.export.control_flow_onnx import (
+    enable_code_export_control_flow,
+    loop_for_onnx,
+)
 from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
 from onnx_diagnostic.export.api import to_onnx


-class TestControlFlow(ExtTestCase):
+class TestControlFlowOnnx(ExtTestCase):
     @never_test()
     def test_loop_one_research(self):
         class Model(torch.nn.Module):
@@ -54,7 +57,7 @@ def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1)

-                return loop_for(n_iter, body, (x,))
+                return loop_for_onnx(n_iter, body, (x,))

         model = Model()
         n_iter = torch.tensor(4, dtype=torch.int64)
@@ -67,7 +70,7 @@ def body(i, x):
             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
         )
         self.assertIn(
-            "torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
+            "torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_L_Model_forward_L_body_",
             str(ep),
         )

@@ -88,7 +91,7 @@ def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1)

-                return loop_for(n_iter, body, (x,))
+                return loop_for_onnx(n_iter, body, (x,))

         model = Model()
         n_iter = torch.tensor(4, dtype=torch.int64)
@@ -101,7 +104,7 @@ def body(i, x):
             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
         )
         self.assertIn(
-            "torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
+            "torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
             str(ep),
         )

@@ -125,7 +128,7 @@ def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1

-                res = loop_for(n_iter, body, (x,))
+                res = loop_for_onnx(n_iter, body, (x,))
                 return res[0] + res[1]

         model = Model()
@@ -139,7 +142,7 @@ def body(i, x):
             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
         )
         self.assertIn(
-            "torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
+            "torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_L_Model_forward_L_body_",
             str(ep),
         )

@@ -160,7 +163,7 @@ def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

-                res = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
+                res = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
                 return res[0] + res[1].T

         model = Model()
@@ -174,7 +177,7 @@ def body(i, x):
             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
         )
         self.assertIn(
-            "torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
+            "torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
             str(ep),
         )
diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 8ee7a84f..b5a651cd 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -43,7 +43,7 @@ def to_onnx(
     :param save_ep: saves the exported program
     :param optimize: optimizes the model
     :param use_control_flow_dispatcher: use the dispatcher created to supported
-        custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
+        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
     :param onnx_plugs: the code was modified to replace some parts with onnx translation
     :param inline: inline local functions
     :return: the output of the selected exporter, usually a structure including
@@ -62,7 +62,7 @@ def to_onnx(
         )

     Some examples using control flows are available in
-    :func:`onnx_diagnostic.export.control_flow.loop_for` or
+    :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
     :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
     """
     if exporter_kwargs and "inline" in exporter_kwargs:
@@ -86,7 +86,7 @@ def to_onnx(
         from experimental_experiment.torch_interpreter import Dispatcher

         if use_control_flow_dispatcher:
-            from .control_flow import create_global_dispatcher
+            from .control_flow_onnx import create_global_dispatcher

             control_flow_dispatcher = create_global_dispatcher()
         else:
diff --git a/onnx_diagnostic/export/control_flow.py b/onnx_diagnostic/export/control_flow_onnx.py
similarity index 92%
rename from onnx_diagnostic/export/control_flow.py
rename to onnx_diagnostic/export/control_flow_onnx.py
index 21814084..18e09908 100644
--- a/onnx_diagnostic/export/control_flow.py
+++ b/onnx_diagnostic/export/control_flow_onnx.py
@@ -55,7 +55,7 @@ def is_exporting() -> bool:
     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()


-def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
+def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
     """
     Python implementation of the loop.

@@ -103,7 +103,7 @@ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
     return tuple(final) if len(final) > 1 else final[0]


-def make_custom_loop_for(
+def make_custom_loop_for_onnx(
     n_iter: torch.Tensor,
     body_fn: Callable,
     reduction_dim: Optional[Sequence[int]],
@@ -139,7 +139,7 @@ def make_custom_loop_for(
         .replace("", "l")
         .replace(".", "_")
     )
-    name = f"loop_for_{full_name}_{srank}_{sred}"
+    name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
     if name in _REGISTERED_SCHEMA:
         return name, _REGISTERED_SCHEMA[name][0]
     sig = inspect.signature(body_fn)
@@ -197,10 +197,10 @@ def convert_custom_loop_into_onnx(
     *args: str,
     body_callable: Callable[..., onnx.ModelProto],
     reduction_dim: Optional[Sequence[int]] = None,
-    name: str = "loop_for",
+    name: str = "loop_for_onnx",
 ) -> Union[str, List[str]]:
     """
-    Converts a custom op ``higher_ops::loop_for...`` into e sequence of node.
+    Converts a custom op ``higher_ops::loop_for_onnx...`` into a sequence of nodes.

     :param g: GreaphBuilder
     :param sts: if not defined, torch does not know the output shapes
@@ -265,9 +265,19 @@ def convert_custom_loop_into_onnx(
         nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
     )

-    sequences = [g.op.SequenceEmpty() for _ in outputs]
+    itypes = [
+        graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type
+        for i in range(1, len(graph.output))
+    ]
+    assert len(outputs) == len(
+        itypes
+    ), f"Length mismatch between outputs={outputs} and graph.output={graph.output}"
+    assert (
+        0 not in itypes
+    ), f"Undefined types are not allowed in itype={itypes}, graph.output={graph.output}"
+    sequences = [g.op.SequenceEmpty(dtype=itype) for itype in itypes]

-    outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
+    outloop = [g.unique_name(f"loop_for_onnx{i}") for i in range(len(sequences))]

     for i, s in enumerate(sequences):
         g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
@@ -285,8 +295,10 @@ def convert_custom_loop_into_onnx(
     ]

     if not sts:
         for i, o in enumerate(outputs):
-            g.set_type(o, graph.output[i].type.tensor_type.elem_type)
-            g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
+            g.set_type(o, graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type)
+            g.set_rank(
+                o, len(graph.output[i].type.sequence_type.elem_type.tensor_type.shape.dims)
+            )

     return outputs if len(outputs) > 1 else outputs[0]

@@ -321,7 +333,7 @@ def convert_into_onnx(
     return container.model_proto


-def loop_for(
+def loop_for_onnx(
     n_iter: Union[torch.SymInt, torch.Tensor],
     body_fn: Callable[..., Tuple[torch.Tensor]],
     args: Sequence[torch.Tensor],
@@ -352,7 +364,7 @@ def loop_for(
         import torch
         import onnxruntime
         from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
+        from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


         class Model(torch.nn.Module):
@@ -360,7 +372,7 @@ def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1)

-                return loop_for(n_iter, body, (x,))
+                return loop_for_onnx(n_iter, body, (x,))


         model = Model()
@@ -398,7 +410,7 @@ def body(i, x):
         import torch
         import onnxruntime
         from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
+        from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


         class Model(torch.nn.Module):
@@ -406,7 +418,7 @@ def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1

-                two = loop_for(n_iter, body, (x,))
+                two = loop_for_onnx(n_iter, body, (x,))
                 return two[0] + two[1]

@@ -445,7 +457,7 @@ def body(i, x):
         import torch
         import onnxruntime
         from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
+        from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


         class Model(torch.nn.Module):
@@ -453,7 +465,7 @@ def forward(self, n_iter, x):
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

-                two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
+                two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
                 return two[0] + two[1].T

@@ -501,7 +513,7 @@ def body(i, x):
             body_mutated_inputs,
             body_outputs,
         ) = check_input_alias_and_mutation_return_outputs(body_gm)
-        name, _custom_ops = make_custom_loop_for(
+        name, _custom_ops = make_custom_loop_for_onnx(
             n_iter,
             body_fn,
             reduction_dim,
@@ -513,4 +525,4 @@ def body(i, x):
         fct = getattr(torch.ops.onnx_higher_ops, name)
         return fct(n_iter, *args)

-    return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
+    return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)
diff --git a/onnx_diagnostic/export/control_flow_research.py b/onnx_diagnostic/export/control_flow_research.py
index c56135d5..261d0a5a 100644
--- a/onnx_diagnostic/export/control_flow_research.py
+++ b/onnx_diagnostic/export/control_flow_research.py
@@ -14,7 +14,7 @@
 )
 from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
 from torch.utils._python_dispatch import _get_current_dispatch_mode
-from .control_flow import _loop_for_fn
+from .control_flow_onnx import _loop_for_onnx_fn


 class SimpleLoopForOp(HigherOrderOperator):
@@ -66,7 +66,7 @@ def simple_loop_for(
         return simple_loop_for_op(n_iter, body_fn, (n_iter, *operands))
     if isinstance(n_iter, (bool, int, float)):
-        return _loop_for_fn(body_fn, n_iter, None, *operands)
+        return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)


 def _validate_input(n_iter, body_fn, operands):
     assert isinstance(
@@ -127,7 +127,7 @@ def loop_for_op_dense(n_iter, body_fn, operands):
     ), f"Dense implementation operands must be a list of tensors and ints {operands}"
     mode = _get_current_dispatch_mode()
     assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-    return _loop_for_fn(body_fn, n_iter, None, *operands)
+    return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)


 @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py
index d9e526e0..d1caa0c9 100644
--- a/onnx_diagnostic/helpers/mini_onnx_builder.py
+++ b/onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -159,7 +159,9 @@ def append_output_sequence(
         """
         if not tensors:
             # empty list
-            self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
+            self.nodes.append(
+                oh.make_node("SequenceEmpty", [], [name], dtype=TensorProto.FLOAT)
+            )
             tensor_type_proto = oh.make_tensor_type_proto(
                 elem_type=TensorProto.FLOAT, shape=None
             )
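
A minimal usage sketch of the renamed API, assembled from the hunks above rather than taken from the patch itself: the `loop_for_onnx` call, the `dynamic_shapes` argument, and the exported-program check mirror the updated tests, while the sample value chosen for `x` and the comments are illustrative assumptions.

```python
import torch
from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        def body(i, x):
            # the body receives the iteration index and the loop operands
            return x[: i.item() + 1].unsqueeze(1)

        return loop_for_onnx(n_iter, body, (x,))


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)  # illustrative input

# outside of export, loop_for_onnx falls back to the eager Python implementation
expected = model(n_iter, x)

# during export, a custom op whose name starts with loop_for_onnx_ is registered
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
assert "torch.ops.onnx_higher_ops.loop_for_onnx_" in str(ep)
```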