3 changes: 2 additions & 1 deletion CHANGELOGS.rst
@@ -4,7 +4,8 @@ Change Logs
0.8.3
+++++

* :pr:`310`: split patches into multiple files
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
* :pr:`310`: splits patches into multiple files
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
* :pr:`304`, :pr:`306`: improves side-by-side comparison, creates command line sbs

1 change: 1 addition & 0 deletions _doc/api/export/index.rst
@@ -8,6 +8,7 @@ onnx_diagnostic.export
api
control_flow
dynamic_shapes
onnx_plug
shape_helper
validate

7 changes: 7 additions & 0 deletions _doc/api/export/onnx_plug.rst
@@ -0,0 +1,7 @@

onnx_diagnostic.export.onnx_plug
================================

.. automodule:: onnx_diagnostic.export.onnx_plug
:members:
:no-undoc-members:
20 changes: 16 additions & 4 deletions _unittests/ut_export/test_control_flow.py
@@ -66,7 +66,10 @@ def body(i, x):
ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
str(ep),
)

onx = to_onnx(
model,
@@ -97,7 +100,10 @@ def body(i, x):
ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
str(ep),
)

onx = to_onnx(
model,
@@ -132,7 +138,10 @@ def body(i, x):
ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
str(ep),
)

onx = to_onnx(
model,
@@ -164,7 +173,10 @@ def body(i, x):
ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep))
self.assertIn(
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
str(ep),
)

onx = to_onnx(
model,
105 changes: 105 additions & 0 deletions _unittests/ut_export/test_onnx_plug.py
@@ -0,0 +1,105 @@
import unittest
import onnx.helper as oh
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
from onnx_diagnostic.export.api import to_onnx


class TestOnnxPlus(ExtTestCase):
def test_onnx_plug_verify(self):
def _test_customadd(x, y):
return x + y

def _test_customadd_shape(x, y):
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)

def make_function_proto():
return oh.make_function(
"onnx_plug",
"_test_customadd",
["x", "y"],
["z"],
[oh.make_node("Add", ["x", "y"], ["z"])],
opset_imports=[oh.make_opsetid("", 22)],
)

rep = EagerDirectReplacementWithOnnx(
_test_customadd, _test_customadd_shape, make_function_proto(), 2, 1
)

x = torch.randn((3, 4), dtype=torch.float32)
y = torch.randn((3, 1), dtype=torch.float32)
self.assertEqualArray(_test_customadd(x, y), x + y)
res = rep.verify(x, y)
self.assertEqualAny(res.eager_outputs, (x + y,))
self.assertEqual(len(res.diffs), 1)
self.assertEqual(res.diffs[0]["abs"], 0)

def test_onnx_plug_export(self):
def _test_customsub(x, y):
return x - y

def _test_customsub_shape(x, y):
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)

def make_function_proto():
return oh.make_function(
"onnx_plug",
"_test_customsub",
["x", "y"],
["z"],
[oh.make_node("Sub", ["x", "y"], ["z"])],
opset_imports=[oh.make_opsetid("", 22)],
)

class Model(torch.nn.Module):
def forward(self, x):
y = x.sum(axis=1, keepdim=True)
d = torch.ops.onnx_plug._test_customsub(x, y)
return torch.abs(d)

replacements = [
EagerDirectReplacementWithOnnx(
_test_customsub, _test_customsub_shape, make_function_proto(), 2, 1
)
]

x = torch.randn((3, 4), dtype=torch.float32)
model = Model()
expected = model(x)
ds = ({0: "d1", 1: "d2"},)
ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds))
self.assertIn("torch.ops.onnx_plug._test_customsub.default", str(ep))
got = ep.module()(x)
self.assertEqualArray(expected, got)

with self.subTest(exporter="custom"):
onx = to_onnx(
model,
(x,),
dynamic_shapes=ds,
exporter="custom",
onnx_plugs=replacements,
target_opset=22,
)
self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,))

if not has_torch("2.9"):
raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
with self.subTest(exporter="onnx-dynamo"):
onx = to_onnx(
model,
(x,),
dynamic_shapes=ds,
exporter="onnx-dynamo",
onnx_plugs=replacements,
target_opset=22,
)
self.assert_onnx_disc(
"test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,)
)


if __name__ == "__main__":
unittest.main(verbosity=2)
88 changes: 81 additions & 7 deletions onnx_diagnostic/export/api.py
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
from .onnx_plug import EagerDirectReplacementWithOnnx


def to_onnx(
@@ -18,6 +19,8 @@ def to_onnx(
save_ep: Optional[str] = None,
optimize: bool = True,
use_control_flow_dispatcher: bool = False,
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
inline: bool = True,
) -> Any:
"""
Common API for exporters. By default, the models are optimized to use the
@@ -41,6 +44,8 @@ def to_onnx(
:param optimize: optimizes the model
:param use_control_flow_dispatcher: use the dispatcher created to supported
custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
:param onnx_plugs: the code was modified to replace some parts with onnx translation
:param inline: inline local functions
:return: the output of the selected exporter, usually a structure including
an onnx model

@@ -55,24 +60,73 @@ def to_onnx(
exporter=exporter,
filename=filename,
)

Some examples using control flows are available in
:func:`onnx_diagnostic.export.control_flow.loop_for` or
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
"""
if exporter_kwargs and "inline" in exporter_kwargs:
assert (
inline == exporter_kwargs["inline"]
), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
exporter_kwargs.pop("inline")
if exporter == "custom":
from experimental_experiment.torch_interpreter import (
to_onnx as _to_onnx,
ExportOptions,
)
from experimental_experiment.xbuilder import OptimizationOptions

if use_control_flow_dispatcher:
from .control_flow import create_global_dispatcher

dispatcher = create_global_dispatcher()

options = None
if exporter_kwargs is not None:
options = exporter_kwargs.pop("options", None)
if options is None:
options = OptimizationOptions(patterns="default+onnxruntime")
if onnx_plugs or use_control_flow_dispatcher:
from experimental_experiment.torch_interpreter import Dispatcher

if use_control_flow_dispatcher:
from .control_flow import create_global_dispatcher

control_flow_dispatcher = create_global_dispatcher()
else:
control_flow_dispatcher = None

class MainDispatcher(Dispatcher):
def __init__(self, previous_dispatcher=None):
super().__init__({})
self.previous_dispatcher = previous_dispatcher

@property
def supported(self):
if self.previous_dispatcher:
return (
set(self.registered_functions) | self.previous_dispatcher.supported
)
return set(self.registered_functions)

def find_function(self, name: Any):
if self.previous_dispatcher:
find = self.previous_dispatcher.find_function(name)
if find:
return find
return Dispatcher.find_function(self, name)

def find_method(self, name: Any):
if self.previous_dispatcher:
find = self.previous_dispatcher.find_method(name)
if find:
return find
return Dispatcher.find_method(self, name)

main_dispatcher = MainDispatcher(control_flow_dispatcher)
if onnx_plugs:
for plug in onnx_plugs:
main_dispatcher.registered_functions[plug.target_name] = (
plug.custom_converter()
)
else:
main_dispatcher = None

return _to_onnx(
mod,
@@ -88,8 +142,9 @@ def to_onnx(
output_dynamic_shapes=output_dynamic_shapes,
export_options=ExportOptions(save_ep=save_ep),
options=options,
inline=inline,
dispatcher=main_dispatcher,
**(exporter_kwargs or {}),
dispatcher=dispatcher if use_control_flow_dispatcher else None,
)

if exporter in ("dynamo", "onnx-dynamo"):
@@ -99,6 +154,10 @@ def to_onnx(
assert (
not output_dynamic_shapes
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
custom_translation_table = {}
if onnx_plugs:
for plug in onnx_plugs:
custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
epo = torch.onnx.export(
mod,
args=args or tuple(),
@@ -111,9 +170,24 @@ def to_onnx(
verbose=verbose,
dump_exported_program=bool(save_ep),
artifacts_dir=os.path.dirname(filename) if filename else ".",
custom_translation_table=custom_translation_table,
**(exporter_kwargs or {}),
)
if optimize:
if not inline and optimize:
ort_fusions.optimize_for_ort(epo.model)

if onnx_plugs:
import onnx_ir as ir
import onnx_ir.passes.common as common_passes

irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
for func in irfunctions:
epo.model.functions[func.identifier()] = func
if inline:
common_passes.InlinePass()(epo.model)
common_passes.RemoveUnusedOpsetsPass()(epo.model)

if inline and optimize:
ort_fusions.optimize_for_ort(epo.model)
if filename:
epo.save(filename, external_data=True)
5 changes: 3 additions & 2 deletions onnx_diagnostic/export/control_flow.py
@@ -36,7 +36,7 @@ def register(self, aten_name: str, converter: Callable):

@contextlib.contextmanager
def enable_code_export_control_flow():
"""Enables the code means to be exported."""
"""Enables the code meant to be exported."""
global _TEST_EXPORT
old = _TEST_EXPORT
_TEST_EXPORT = True
@@ -134,7 +138,8 @@ def make_custom_loop_for(
assert body_outputs is not None, "body_outputs cannot be None"
srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
name = f"loop_for_{body_fn.__name__}_{id(body_fn)}_{srank}_{sred}"
full_name = body_fn.__qualname__.replace("<locals>", "L").replace(".", "_")
name = f"loop_for_{full_name}_{srank}_{sred}"
if name in _REGISTERED_SCHEMA:
return name, _REGISTERED_SCHEMA[name][0]
sig = inspect.signature(body_fn)