From aff13658a72e8f85f704a5db85986f731a3960a9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sun, 27 Apr 2025 15:20:24 +0200 Subject: [PATCH 1/6] frsit draft --- _doc/api/torch_export_patches/index.rst | 1 + .../api/torch_export_patches/patch_module.rst | 7 + .../test_patch_module.py | 33 +++ .../torch_export_patches/patch_module.py | 234 ++++++++++++++++++ 4 files changed, 275 insertions(+) create mode 100644 _doc/api/torch_export_patches/patch_module.rst create mode 100644 _unittests/ut_torch_export_patches/test_patch_module.py create mode 100644 onnx_diagnostic/torch_export_patches/patch_module.py diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index b35e8c62..d47ec2f1 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -7,6 +7,7 @@ onnx_diagnostic.torch_export_patches patches/index patch_inputs + patch_module .. automodule:: onnx_diagnostic.torch_export_patches diff --git a/_doc/api/torch_export_patches/patch_module.rst b/_doc/api/torch_export_patches/patch_module.rst new file mode 100644 index 00000000..f3759993 --- /dev/null +++ b/_doc/api/torch_export_patches/patch_module.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_export_patches.patch_module +================================================= + +.. automodule:: onnx_diagnostic.torch_export_patches.patch_module + :members: + :no-undoc-members: diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py new file mode 100644 index 00000000..a8330c06 --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -0,0 +1,33 @@ +import ast +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.torch_export_patches.patch_module import transform_method + + +class TestPatchModule(ExtTestCase): + def test_rewrite_forward(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + + x, y = torch.rand((3, 4)), torch.rand((3, 4)) + Model()(x, y) + tree, me = transform_method(Model.forward) + + print("-------------") + print(ast.dump(tree.body[0], indent=4)) + print("-------------") + code = ast.unparse(tree) + print(code) + print("-------------") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py new file mode 100644 index 00000000..7f47ef41 --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -0,0 +1,234 @@ +import ast +import inspect +import types +import textwrap + + +class RewriteControlFlow(ast.NodeTransformer): + def __init__(self, wrapper_name): + self.wrapper_name = wrapper_name + self.counter = 0 + self.current_func_args = None + + def visit_FunctionDef(self, node): + # Capture argument names for branch functions + old_args = self.current_func_args + self.current_func_args = [arg.arg for arg in node.args.args] + node.body = [self.visit(n) for n in node.body] + self.current_func_args = old_args + return node + + def visit_If(self, node): + # First recurse into subnodes + node = self.generic_visit(node) + test_node = node.test + # Case 1: simple assignment in both branches + if ( + len(node.body) == 1 + and isinstance(node.body[0], ast.Assign) + and len(node.orelse) == 1 + and isinstance(node.orelse[0], ast.Assign) + and self.current_func_args is not None + ): + then_assign = node.body[0] + else_assign = node.orelse[0] + tgt = then_assign.targets[0] + if ( + isinstance(tgt, ast.Name) + and isinstance(else_assign.targets[0], ast.Name) + and tgt.id == else_assign.targets[0].id + ): + self.counter += 1 + then_name = f"{self.wrapper_name}_then_{self.counter}" + else_name = f"{self.wrapper_name}_else_{self.counter}" + then_expr = then_assign.value + else_expr = else_assign.value + # extract free variables + then_vars = sorted( + { + n.id + for n in ast.walk(then_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + else_vars = sorted( + { + n.id + for n in ast.walk(else_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + # build local funcs + then_args = [ast.arg(arg=v, annotation=None) for v in then_vars] + then_def = ast.FunctionDef( + name=then_name, + args=ast.arguments( + posonlyargs=[], + args=then_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[ast.Return(then_expr)], + decorator_list=[], + returns=None, + ) + else_args = [ast.arg(arg=v, annotation=None) for v in else_vars] + else_def = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], + args=else_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[ast.Return(else_expr)], + decorator_list=[], + returns=None, + ) + # fix locations + for n in (then_def, else_def): + ast.copy_location(n, node) + ast.fix_missing_locations(n) + # wrapper call and assignment + then_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() + ) + else_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() + ) + call = ast.Call( + func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), + args=[ + test_node, + ast.Name(id=then_name, ctx=ast.Load()), + ast.Name(id=else_name, ctx=ast.Load()), + then_args_tuple, + else_args_tuple, + ], + keywords=[], + ) + assign = ast.Assign(targets=[tgt], value=call) + ast.copy_location(assign, node) + ast.fix_missing_locations(assign) + return [then_def, else_def, assign] + # Case 2: simple return in both branches + if ( + len(node.body) == 1 + and isinstance(node.body[0], ast.Return) + and len(node.orelse) == 1 + and isinstance(node.orelse[0], ast.Return) + and self.current_func_args is not None + ): + then_ret = node.body[0] + else_ret = node.orelse[0] + then_expr = then_ret.value + else_expr = else_ret.value + self.counter += 1 + then_name = f"{self.wrapper_name}_then_{self.counter}" + else_name = f"{self.wrapper_name}_else_{self.counter}" + # extract free variables + then_vars = sorted( + { + n.id + for n in ast.walk(then_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + else_vars = sorted( + { + n.id + for n in ast.walk(else_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + # build local funcs + then_args = [ast.arg(arg=v, annotation=None) for v in then_vars] + then_def = ast.FunctionDef( + name=then_name, + args=ast.arguments( + posonlyargs=[], args=then_args, kwonlyargs=[], kw_defaults=[], defaults=[] + ), + body=[ast.Return(then_expr)], + decorator_list=[], + returns=None, + ) + else_args = [ast.arg(arg=v, annotation=None) for v in else_vars] + else_def = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], args=else_args, kwonlyargs=[], kw_defaults=[], defaults=[] + ), + body=[ast.Return(else_expr)], + decorator_list=[], + returns=None, + ) + for n in (then_def, else_def): + ast.copy_location(n, node) + ast.fix_missing_locations(n) + # wrapper call and return + then_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() + ) + else_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() + ) + call = ast.Call( + func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), + args=[ + test_node, + ast.Name(id=then_name, ctx=ast.Load()), + ast.Name(id=else_name, ctx=ast.Load()), + then_args_tuple, + else_args_tuple, + ], + keywords=[], + ) + ret = ast.Return(call) + ast.copy_location(ret, node) + ast.fix_missing_locations(ret) + return [then_def, else_def, ret] + return node + + def generic_visit(self, node): + return super().generic_visit(node) + + +def _fix_missing_locations_node(node): + if not hasattr(node, "lineno"): + node.lineno = 999 + for chi in ast.iter_child_nodes(node): + _fix_missing_locations_node(chi) + + +def _fix_missing_locations(new_tree): + for node in ast.walk(new_tree): + _fix_missing_locations_node(node) + + +def transform_method(func, wrapper_name="torch_cond"): + """ + Returns a new function based on `func` where every test (if, while, assert, + ternary, comparison, boolean op) is replaced by a call to `wrapper_name`. + + wrapper_name should refer to a function taking a single boolean argument. + """ + # Retrieve source of the function + src = inspect.getsource(func) + # Parse into AST + tree = ast.parse(textwrap.dedent(src)) + # Apply transformation + transformer = RewriteControlFlow(wrapper_name) + new_tree = transformer.visit(tree) + ast.fix_missing_locations(new_tree) + + # fix other location + _fix_missing_locations(new_tree) + mod = compile(new_tree, filename="", mode="exec") + namespace = {} + exec(mod, func.__globals__, namespace) + new_func = namespace.get(func.__name__) + if not isinstance(new_func, types.FunctionType): + raise RuntimeError("Transformed function not found") + return new_tree, new_func From 05b04f7120e9c53420d34802bc37480ba3215056 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 28 Apr 2025 16:03:53 +0200 Subject: [PATCH 2/6] add support for os_ort --- _doc/examples/plot_export_tiny_llm_patched.py | 2 +- _doc/index.rst | 1 + _doc/recipes/plot_dynamic_shapes_max.py | 4 ++- .../ut_torch_models/test_test_helpers.py | 23 ++++++++++++- onnx_diagnostic/__init__.py | 2 +- onnx_diagnostic/torch_models/test_helper.py | 33 +++++++++++-------- 6 files changed, 48 insertions(+), 17 deletions(-) diff --git a/_doc/examples/plot_export_tiny_llm_patched.py b/_doc/examples/plot_export_tiny_llm_patched.py index 60a20d15..5ed9566e 100644 --- a/_doc/examples/plot_export_tiny_llm_patched.py +++ b/_doc/examples/plot_export_tiny_llm_patched.py @@ -101,7 +101,7 @@ # %% # If they are not registered, function -# func:`onnx_diagnostic.torch_export_patches.torch_export_patches` +# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` # should take care of it. Then we export. with torch_export_patches(patch_transformers=True, verbose=10) as modificator: diff --git a/_doc/index.rst b/_doc/index.rst index 6d4792dd..5d105c5e 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -173,6 +173,7 @@ Size of the package: Older versions ++++++++++++++ +* `0.4.4 <../v0.4.4/index.html>`_ * `0.4.3 <../v0.4.3/index.html>`_ * `0.4.2 <../v0.4.2/index.html>`_ * `0.4.1 <../v0.4.1/index.html>`_ diff --git a/_doc/recipes/plot_dynamic_shapes_max.py b/_doc/recipes/plot_dynamic_shapes_max.py index 83844c7c..1d681740 100644 --- a/_doc/recipes/plot_dynamic_shapes_max.py +++ b/_doc/recipes/plot_dynamic_shapes_max.py @@ -185,4 +185,6 @@ def forward(self, x, y, fact): # is hidden in a custom operator. -doc.plot_legend("max(d1, d2)\nwith d1, d2 dimensions", "dynamic shapes", "green") +doc.plot_legend( + "Fixed in torch==2.8\nmax(d1, d2)\nwith d1, d2\ndimensions", "dynamic shapes", "green" +) diff --git a/_unittests/ut_torch_models/test_test_helpers.py b/_unittests/ut_torch_models/test_test_helpers.py index 1a2a99d9..03928f58 100644 --- a/_unittests/ut_torch_models/test_test_helpers.py +++ b/_unittests/ut_torch_models/test_test_helpers.py @@ -66,7 +66,7 @@ def test_validate_model_export(self): @requires_torch("2.7") @hide_stdout() @ignore_warnings(FutureWarning) - def test_validate_model_onnx_dynamo(self): + def test_validate_model_onnx_dynamo_ir(self): mid = "arnir0/Tiny-LLM" summary, data = validate_model( mid, @@ -87,6 +87,27 @@ def test_validate_model_onnx_dynamo(self): onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10 ) + @requires_torch("2.7") + @hide_stdout() + @ignore_warnings(FutureWarning) + def test_validate_model_onnx_dynamo_os_ort(self): + mid = "arnir0/Tiny-LLM" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="onnx-dynamo", + dump_folder="dump_test_validate_model_onnx_dynamo", + patch=True, + stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + optimization="os_ort", + ) + self.assertIsInstance(summary, dict) + self.assertIsInstance(data, dict) + self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) + onnx_filename = data["onnx_filename"] + self.assertExists(onnx_filename) + @requires_torch("2.7") @hide_stdout() @ignore_warnings(FutureWarning) diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index 2d02cd2f..7a45dfe1 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.4.3" +__version__ = "0.4.4" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index aa93ece7..7620dcbb 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -4,6 +4,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import time import onnx +import onnxscript +import onnxscript.rewriter.ort_fusions as ort_fusions import torch from ..export import CoupleInputsDynamicShapes from ..helpers import max_diff, string_type, string_diff @@ -917,11 +919,10 @@ def call_torch_export_onnx( :return: two dictionaries, one with some metrics, another one with whatever the function produces """ - assert optimization in { - "", - "ir", - None, - }, f"unexpected value for optimization={optimization}" + available = {"", "ir", "os_ort"} + assert ( + optimization in available + ), f"unexpected value for optimization={optimization}, available={available}" assert exporter in { "onnx-dynamo", "onnx-script", @@ -1001,16 +1002,22 @@ def call_torch_export_onnx( print(epo) print("[call_torch_export_onnx] -- End of ONNXProgram") - if optimization == "ir": + if optimization in {"ir", "os_ort"}: if verbose: print(f"[call_torch_export_onnx] starts optimization={optimization!r}...") - _quiet_or_not_quiet( - quiet, - "export_onnx_opt_ir", - summary, - data, - (lambda epo=epo: epo.optimize()), - ) + if optimization == "ir": + label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize()) + else: + + def _os_ort_optim(epo): + onnxscript.optimizer.optimize_ir(epo.model) + optimized = ort_fusions.optimize_for_ort(epo.model) + epo.model = ( + optimized if isinstance(optimized, onnxscript.ir.Model) else optimized[0] + ) + + label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo)) + _quiet_or_not_quiet(quiet, label, summary, data, f_optim) if "ERR_export_onnx_opt_ir" in summary: return summary, data if verbose: From 194ffc2689a59ba416e81ab72b0b4c187fa0cdc7 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 28 Apr 2025 16:47:19 +0200 Subject: [PATCH 3/6] pass --- _doc/api/torch_export_patches/index.rst | 1 - .../api/torch_export_patches/patch_module.rst | 7 - .../test_patch_module.py | 33 --- .../torch_export_patches/patch_module.py | 234 ------------------ onnx_diagnostic/torch_models/test_helper.py | 16 +- 5 files changed, 14 insertions(+), 277 deletions(-) delete mode 100644 _doc/api/torch_export_patches/patch_module.rst delete mode 100644 _unittests/ut_torch_export_patches/test_patch_module.py delete mode 100644 onnx_diagnostic/torch_export_patches/patch_module.py diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index d47ec2f1..b35e8c62 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -7,7 +7,6 @@ onnx_diagnostic.torch_export_patches patches/index patch_inputs - patch_module .. automodule:: onnx_diagnostic.torch_export_patches diff --git a/_doc/api/torch_export_patches/patch_module.rst b/_doc/api/torch_export_patches/patch_module.rst deleted file mode 100644 index f3759993..00000000 --- a/_doc/api/torch_export_patches/patch_module.rst +++ /dev/null @@ -1,7 +0,0 @@ - -onnx_diagnostic.torch_export_patches.patch_module -================================================= - -.. automodule:: onnx_diagnostic.torch_export_patches.patch_module - :members: - :no-undoc-members: diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py deleted file mode 100644 index a8330c06..00000000 --- a/_unittests/ut_torch_export_patches/test_patch_module.py +++ /dev/null @@ -1,33 +0,0 @@ -import ast -import unittest -import torch -from onnx_diagnostic.ext_test_case import ExtTestCase -from onnx_diagnostic.torch_export_patches.patch_module import transform_method - - -class TestPatchModule(ExtTestCase): - def test_rewrite_forward(self): - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - if x.sum() > 0: - return x + y - else: - return torch.abs(x) + y - - x, y = torch.rand((3, 4)), torch.rand((3, 4)) - Model()(x, y) - tree, me = transform_method(Model.forward) - - print("-------------") - print(ast.dump(tree.body[0], indent=4)) - print("-------------") - code = ast.unparse(tree) - print(code) - print("-------------") - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py deleted file mode 100644 index 7f47ef41..00000000 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ /dev/null @@ -1,234 +0,0 @@ -import ast -import inspect -import types -import textwrap - - -class RewriteControlFlow(ast.NodeTransformer): - def __init__(self, wrapper_name): - self.wrapper_name = wrapper_name - self.counter = 0 - self.current_func_args = None - - def visit_FunctionDef(self, node): - # Capture argument names for branch functions - old_args = self.current_func_args - self.current_func_args = [arg.arg for arg in node.args.args] - node.body = [self.visit(n) for n in node.body] - self.current_func_args = old_args - return node - - def visit_If(self, node): - # First recurse into subnodes - node = self.generic_visit(node) - test_node = node.test - # Case 1: simple assignment in both branches - if ( - len(node.body) == 1 - and isinstance(node.body[0], ast.Assign) - and len(node.orelse) == 1 - and isinstance(node.orelse[0], ast.Assign) - and self.current_func_args is not None - ): - then_assign = node.body[0] - else_assign = node.orelse[0] - tgt = then_assign.targets[0] - if ( - isinstance(tgt, ast.Name) - and isinstance(else_assign.targets[0], ast.Name) - and tgt.id == else_assign.targets[0].id - ): - self.counter += 1 - then_name = f"{self.wrapper_name}_then_{self.counter}" - else_name = f"{self.wrapper_name}_else_{self.counter}" - then_expr = then_assign.value - else_expr = else_assign.value - # extract free variables - then_vars = sorted( - { - n.id - for n in ast.walk(then_expr) - if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) - } - ) - else_vars = sorted( - { - n.id - for n in ast.walk(else_expr) - if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) - } - ) - # build local funcs - then_args = [ast.arg(arg=v, annotation=None) for v in then_vars] - then_def = ast.FunctionDef( - name=then_name, - args=ast.arguments( - posonlyargs=[], - args=then_args, - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=[ast.Return(then_expr)], - decorator_list=[], - returns=None, - ) - else_args = [ast.arg(arg=v, annotation=None) for v in else_vars] - else_def = ast.FunctionDef( - name=else_name, - args=ast.arguments( - posonlyargs=[], - args=else_args, - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=[ast.Return(else_expr)], - decorator_list=[], - returns=None, - ) - # fix locations - for n in (then_def, else_def): - ast.copy_location(n, node) - ast.fix_missing_locations(n) - # wrapper call and assignment - then_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() - ) - else_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() - ) - call = ast.Call( - func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), - args=[ - test_node, - ast.Name(id=then_name, ctx=ast.Load()), - ast.Name(id=else_name, ctx=ast.Load()), - then_args_tuple, - else_args_tuple, - ], - keywords=[], - ) - assign = ast.Assign(targets=[tgt], value=call) - ast.copy_location(assign, node) - ast.fix_missing_locations(assign) - return [then_def, else_def, assign] - # Case 2: simple return in both branches - if ( - len(node.body) == 1 - and isinstance(node.body[0], ast.Return) - and len(node.orelse) == 1 - and isinstance(node.orelse[0], ast.Return) - and self.current_func_args is not None - ): - then_ret = node.body[0] - else_ret = node.orelse[0] - then_expr = then_ret.value - else_expr = else_ret.value - self.counter += 1 - then_name = f"{self.wrapper_name}_then_{self.counter}" - else_name = f"{self.wrapper_name}_else_{self.counter}" - # extract free variables - then_vars = sorted( - { - n.id - for n in ast.walk(then_expr) - if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) - } - ) - else_vars = sorted( - { - n.id - for n in ast.walk(else_expr) - if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) - } - ) - # build local funcs - then_args = [ast.arg(arg=v, annotation=None) for v in then_vars] - then_def = ast.FunctionDef( - name=then_name, - args=ast.arguments( - posonlyargs=[], args=then_args, kwonlyargs=[], kw_defaults=[], defaults=[] - ), - body=[ast.Return(then_expr)], - decorator_list=[], - returns=None, - ) - else_args = [ast.arg(arg=v, annotation=None) for v in else_vars] - else_def = ast.FunctionDef( - name=else_name, - args=ast.arguments( - posonlyargs=[], args=else_args, kwonlyargs=[], kw_defaults=[], defaults=[] - ), - body=[ast.Return(else_expr)], - decorator_list=[], - returns=None, - ) - for n in (then_def, else_def): - ast.copy_location(n, node) - ast.fix_missing_locations(n) - # wrapper call and return - then_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() - ) - else_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() - ) - call = ast.Call( - func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), - args=[ - test_node, - ast.Name(id=then_name, ctx=ast.Load()), - ast.Name(id=else_name, ctx=ast.Load()), - then_args_tuple, - else_args_tuple, - ], - keywords=[], - ) - ret = ast.Return(call) - ast.copy_location(ret, node) - ast.fix_missing_locations(ret) - return [then_def, else_def, ret] - return node - - def generic_visit(self, node): - return super().generic_visit(node) - - -def _fix_missing_locations_node(node): - if not hasattr(node, "lineno"): - node.lineno = 999 - for chi in ast.iter_child_nodes(node): - _fix_missing_locations_node(chi) - - -def _fix_missing_locations(new_tree): - for node in ast.walk(new_tree): - _fix_missing_locations_node(node) - - -def transform_method(func, wrapper_name="torch_cond"): - """ - Returns a new function based on `func` where every test (if, while, assert, - ternary, comparison, boolean op) is replaced by a call to `wrapper_name`. - - wrapper_name should refer to a function taking a single boolean argument. - """ - # Retrieve source of the function - src = inspect.getsource(func) - # Parse into AST - tree = ast.parse(textwrap.dedent(src)) - # Apply transformation - transformer = RewriteControlFlow(wrapper_name) - new_tree = transformer.visit(tree) - ast.fix_missing_locations(new_tree) - - # fix other location - _fix_missing_locations(new_tree) - mod = compile(new_tree, filename="", mode="exec") - namespace = {} - exec(mod, func.__globals__, namespace) - new_func = namespace.get(func.__name__) - if not isinstance(new_func, types.FunctionType): - raise RuntimeError("Transformed function not found") - return new_tree, new_func diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 7620dcbb..24903230 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -1046,12 +1046,17 @@ def call_torch_export_custom( :return: two dictionaries, one with some metrics, another one with whatever the function produces """ - assert optimization in { + available = { "", "default", "default+onnxruntime", + "default+os_ort", + "default+onnxruntime+os_ort", None, - }, f"unexpected value for optimization={optimization}" + } + assert ( + optimization in available + ), f"unexpected value for optimization={optimization}, available={available}" assert exporter in { "custom", "custom-strict", @@ -1085,6 +1090,10 @@ def call_torch_export_custom( from experimental_experiment.torch_interpreter import to_onnx, ExportOptions from experimental_experiment.xbuilder import OptimizationOptions + spl = optimization.split("+") if optimization else [] + os_ort = "os_ort" in spl + optimization = "+".join(_ for _ in spl if _ != "os_ort") + export_options = ExportOptions( strict=strict, decomposition_table=( @@ -1188,6 +1197,9 @@ def call_torch_export_custom( assert epo is not None, "no onnx export was found" if verbose: print("[call_torch_export_custom] done (export)") + + if os_ort: + pass data["onnx_program"] = epo return summary, data From a5d35555dc56c08832ec71fe8c3c00c854146e92 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 28 Apr 2025 19:35:37 +0200 Subject: [PATCH 4/6] fix issues --- onnx_diagnostic/torch_models/test_helper.py | 33 ++++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 24903230..06c907de 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -1012,9 +1012,12 @@ def call_torch_export_onnx( def _os_ort_optim(epo): onnxscript.optimizer.optimize_ir(epo.model) optimized = ort_fusions.optimize_for_ort(epo.model) - epo.model = ( - optimized if isinstance(optimized, onnxscript.ir.Model) else optimized[0] - ) + if isinstance(optimized, tuple): + for k, v in optimized[1].items(): + summary[f"op_opt_fused_{k}"] = v + epo.model = optimized[0] + else: + epo.model = optimized label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo)) _quiet_or_not_quiet(quiet, label, summary, data, f_optim) @@ -1199,7 +1202,29 @@ def call_torch_export_custom( print("[call_torch_export_custom] done (export)") if os_ort: - pass + if verbose: + print("[call_torch_export_custom] conversion to IR...") + begin = time.perf_counter() + ir_model = epo.to_ir() + duration = time.perf_counter() - begin + summary["time_optim_to_ir"] = duration + if verbose: + print(f"[call_torch_export_custom] done in {duration}") + print("[call_torch_export_custom] start optimization...") + begin = time.perf_counter() + onnxscript.optimizer.optimize_ir(ir_model) + ir_optimized = ort_fusions.optimize_for_ort(ir_model) + if isinstance(ir_optimized, tuple): + report = ir_optimized[1] + for k, v in report.items(): + summary[f"op_opt_fused_{k}"] = v + ir_optimized = ir_optimized[0] + epo.model = ir_optimized + duration = time.perf_counter() - begin + summary["time_optim_os_ort"] = duration + if verbose: + print(f"[call_torch_export_custom] done in {duration}") + data["onnx_program"] = epo return summary, data From 08537b889032e54e71b3dce6a0d0c4354124e582 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 29 Apr 2025 11:58:11 +0200 Subject: [PATCH 5/6] ut --- .../ut_torch_models/test_test_helpers.py | 22 +++++ ...rnir0_Tiny-LLM-custom-default+os_ort.stats | 88 +++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats diff --git a/_unittests/ut_torch_models/test_test_helpers.py b/_unittests/ut_torch_models/test_test_helpers.py index 03928f58..0ebb5918 100644 --- a/_unittests/ut_torch_models/test_test_helpers.py +++ b/_unittests/ut_torch_models/test_test_helpers.py @@ -108,6 +108,28 @@ def test_validate_model_onnx_dynamo_os_ort(self): onnx_filename = data["onnx_filename"] self.assertExists(onnx_filename) + @requires_torch("2.7") + @hide_stdout() + @ignore_warnings(FutureWarning) + @requires_experimental() + def test_validate_model_custom_os_ort(self): + mid = "arnir0/Tiny-LLM" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="custom", + dump_folder="test_validate_model_custom_os_ort", + patch=True, + stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + optimization="default+os_ort", + ) + self.assertIsInstance(summary, dict) + self.assertIsInstance(data, dict) + self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) + onnx_filename = data["onnx_filename"] + self.assertExists(onnx_filename) + @requires_torch("2.7") @hide_stdout() @ignore_warnings(FutureWarning) diff --git a/test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats b/test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats new file mode 100644 index 00000000..92ec6e1f --- /dev/null +++ b/test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats @@ -0,0 +1,88 @@ +:disc_onnx_ort_run_abs:8.344650268554688e-07; +:disc_onnx_ort_run_dnan:0; +:disc_onnx_ort_run_n:204672.0; +:disc_onnx_ort_run_rel:0.00033429367266728856; +:disc_onnx_ort_run_sum:0.018881732111708516; +:disc_patched_abs:0; +:disc_patched_dnan:0; +:disc_patched_n:204672.0; +:disc_patched_rel:0; +:disc_patched_sum:0.0; +:dump_folder:test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort; +:dump_folder_name:arnir0_Tiny-LLM-custom-default+os_ort; +:export_args:(); +:export_exporter:custom; +:export_kwargs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96])); +:export_optimization:default+os_ort; +:export_strict:False; +:model_class:LlamaForCausalLM; +:model_config:{'vocab_size':32000,'max_position_embeddings':1024,'hidden_size':192,'intermediate_size':1024,'num_hidden_layers':1,'num_attention_heads':2,'num_key_value_heads':1,'hidden_act':'silu','initializer_range':0.02,'rms_norm_eps':1e-05,'pretraining_tp':1,'use_cache':True,'rope_theta':10000.0,'rope_scaling':None,'attention_bias':False,'attention_dropout':0.0,'mlp_bias':False,'head_dim':96,'return_dict':True,'output_hidden_states':False,'output_attentions':False,'torchscript':False,'torch_dtype':'float32','use_bfloat16':False,'tf_legacy_loss':False,'pruned_heads':{},'tie_word_embeddings':False,'chunk_size_feed_forward':0,'is_encoder_decoder':False,'is_decoder':False,'cross_attention_hidden_size':None,'add_cross_attention':False,'tie_encoder_decoder':False,'max_length':20,'min_length':0,'do_sample':False,'early_stopping':False,'num_beams':1,'num_beam_groups':1,'diversity_penalty':0.0,'temperature':1.0,'top_k':50,'top_p':1.0,'typical_p':1.0,'repetition_penalty':1.0,'length_penalty':1.0,'no_repeat_ngram_size':0,'encoder_no_repeat_ngram_size':0,'bad_words_ids':None,'num_return_sequences':1,'output_scores':False,'return_dict_in_generate':False,'forced_bos_token_id':None,'forced_eos_token_id':None,'remove_invalid_values':False,'exponential_decay_length_penalty':None,'suppress_tokens':None,'begin_suppress_tokens':None,'architectures':['LlamaForCausalLM'],'finetuning_task':None,'id2label':{0:'LABEL_0',1:'LABEL_1'},'label2id':{'LABEL_0':0,'LABEL_1':1},'tokenizer_class':None,'prefix':None,'bos_token_id':1,'pad_token_id':None,'eos_token_id':2,'sep_token_id':None,'decoder_start_token_id':None,'task_specific_params':None,'problem_type':None,'_attn_implementation_autoset':True,'transformers_version':'4.52.0.dev0','model_type':'llama'}; +:model_config_class:LlamaConfig; +:model_expected:CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96])); +:model_id:arnir0/Tiny-LLM; +:model_inputs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96])); +:model_inputs_opionts:; +:model_nweights:12988992; +:model_shapes:str; +:model_size:51955968; +:model_task:text-generation; +:onnx_filename:test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.onnx; +:onnx_opt_optimized:1; +:onnx_ort_inputs:dict(input_ids:A7s2x3,attention_mask:A7s2x33,position_ids:A7s2x3,past_key_values_key_cache_0:A1s2x1x30x96,past_key_values_value_cache_0:A1s2x1x30x96); +:onnx_size:43892; +:op_opt_all_added:117; +:op_opt_all_removed:218; +:op_opt_all_time_in:0.41514878199086525; +:op_opt_cst_added:0; +:op_opt_cst_removed:0; +:op_opt_cst_time_in:0.0; +:op_opt_export_optimization:0.41514878199086525; +:op_opt_fused_attention:0; +:op_opt_fused_bias_gelu:0; +:op_opt_fused_cos_sin_cache:0; +:op_opt_fused_erf_gelu:0; +:op_opt_fused_gelu:0; +:op_opt_fused_gqa:0; +:op_opt_fused_mha:0; +:op_opt_fused_packed_qkv_for_gqa:0; +:op_opt_fused_partial_rotary_embedding:0; +:op_opt_fused_rms_normalization:0; +:op_opt_fused_rotary_embedding:0; +:op_opt_fused_sdpa:1; +:op_opt_fused_skip_layer_normalization:0; +:op_opt_fused_skip_rms_normalization:0; +:op_opt_max_iter:13; +:op_opt_n_applied:49; +:op_opt_unique_applied:24; +:op_opt_unique_matched:48; +:time_create:5.664649608999753; +:time_export_export_onnx_c:7.78976812899964; +:time_export_optimization:0.41514878199086525; +:time_onnx_save:0.21937757000159763; +:time_optim_os_ort:0.2757421719998092; +:time_optim_to_ir:0.004986361000192119; +:time_run:0.04424335000112478; +:time_run_patched:0.006461869001213927; +:time_time_onnx_ort_create:0.10978344900104275; +:time_time_onnx_ort_run:0.14714581799853477; +:version_date:2025-04-29T11:56:42; +:version_device:; +:version_do_run:True; +:version_drop_inputs:[]; +:version_dtype:; +:version_dump_folder:test_validate_model_custom_os_ort; +:version_exporter:custom; +:version_model_id:arnir0/Tiny-LLM; +:version_numpy:2.2.5; +:version_onnx:1.19.0; +:version_onnx_diagnostic:0.4.4; +:version_onnxruntime:1.22.0+cu126; +:version_onnxscript:0.3.0.dev20250301; +:version_optimization:default+os_ort; +:version_ortfusiontype:; +:version_patch:True; +:version_quiet:False; +:version_stop_if_static:2; +:version_torch:2.8.0.dev20250423+cu126; +:version_trained:False; +:version_transformers:4.52.0.dev0; From faad2c0d691b9762d6c50c04fefb34f6a4f7843f Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 29 Apr 2025 11:58:27 +0200 Subject: [PATCH 6/6] clean --- .../ut_torch_models/test_test_helpers.py | 2 +- ...rnir0_Tiny-LLM-custom-default+os_ort.stats | 88 ------------------- 2 files changed, 1 insertion(+), 89 deletions(-) delete mode 100644 test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats diff --git a/_unittests/ut_torch_models/test_test_helpers.py b/_unittests/ut_torch_models/test_test_helpers.py index 0ebb5918..90dc1b52 100644 --- a/_unittests/ut_torch_models/test_test_helpers.py +++ b/_unittests/ut_torch_models/test_test_helpers.py @@ -119,7 +119,7 @@ def test_validate_model_custom_os_ort(self): do_run=True, verbose=10, exporter="custom", - dump_folder="test_validate_model_custom_os_ort", + dump_folder="dump_validate_model_custom_os_ort", patch=True, stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, optimization="default+os_ort", diff --git a/test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats b/test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats deleted file mode 100644 index 92ec6e1f..00000000 --- a/test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.stats +++ /dev/null @@ -1,88 +0,0 @@ -:disc_onnx_ort_run_abs:8.344650268554688e-07; -:disc_onnx_ort_run_dnan:0; -:disc_onnx_ort_run_n:204672.0; -:disc_onnx_ort_run_rel:0.00033429367266728856; -:disc_onnx_ort_run_sum:0.018881732111708516; -:disc_patched_abs:0; -:disc_patched_dnan:0; -:disc_patched_n:204672.0; -:disc_patched_rel:0; -:disc_patched_sum:0.0; -:dump_folder:test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort; -:dump_folder_name:arnir0_Tiny-LLM-custom-default+os_ort; -:export_args:(); -:export_exporter:custom; -:export_kwargs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96])); -:export_optimization:default+os_ort; -:export_strict:False; -:model_class:LlamaForCausalLM; -:model_config:{'vocab_size':32000,'max_position_embeddings':1024,'hidden_size':192,'intermediate_size':1024,'num_hidden_layers':1,'num_attention_heads':2,'num_key_value_heads':1,'hidden_act':'silu','initializer_range':0.02,'rms_norm_eps':1e-05,'pretraining_tp':1,'use_cache':True,'rope_theta':10000.0,'rope_scaling':None,'attention_bias':False,'attention_dropout':0.0,'mlp_bias':False,'head_dim':96,'return_dict':True,'output_hidden_states':False,'output_attentions':False,'torchscript':False,'torch_dtype':'float32','use_bfloat16':False,'tf_legacy_loss':False,'pruned_heads':{},'tie_word_embeddings':False,'chunk_size_feed_forward':0,'is_encoder_decoder':False,'is_decoder':False,'cross_attention_hidden_size':None,'add_cross_attention':False,'tie_encoder_decoder':False,'max_length':20,'min_length':0,'do_sample':False,'early_stopping':False,'num_beams':1,'num_beam_groups':1,'diversity_penalty':0.0,'temperature':1.0,'top_k':50,'top_p':1.0,'typical_p':1.0,'repetition_penalty':1.0,'length_penalty':1.0,'no_repeat_ngram_size':0,'encoder_no_repeat_ngram_size':0,'bad_words_ids':None,'num_return_sequences':1,'output_scores':False,'return_dict_in_generate':False,'forced_bos_token_id':None,'forced_eos_token_id':None,'remove_invalid_values':False,'exponential_decay_length_penalty':None,'suppress_tokens':None,'begin_suppress_tokens':None,'architectures':['LlamaForCausalLM'],'finetuning_task':None,'id2label':{0:'LABEL_0',1:'LABEL_1'},'label2id':{'LABEL_0':0,'LABEL_1':1},'tokenizer_class':None,'prefix':None,'bos_token_id':1,'pad_token_id':None,'eos_token_id':2,'sep_token_id':None,'decoder_start_token_id':None,'task_specific_params':None,'problem_type':None,'_attn_implementation_autoset':True,'transformers_version':'4.52.0.dev0','model_type':'llama'}; -:model_config_class:LlamaConfig; -:model_expected:CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96])); -:model_id:arnir0/Tiny-LLM; -:model_inputs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96])); -:model_inputs_opionts:; -:model_nweights:12988992; -:model_shapes:str; -:model_size:51955968; -:model_task:text-generation; -:onnx_filename:test_validate_model_custom_os_ort/arnir0_Tiny-LLM-custom-default+os_ort/arnir0_Tiny-LLM-custom-default+os_ort.onnx; -:onnx_opt_optimized:1; -:onnx_ort_inputs:dict(input_ids:A7s2x3,attention_mask:A7s2x33,position_ids:A7s2x3,past_key_values_key_cache_0:A1s2x1x30x96,past_key_values_value_cache_0:A1s2x1x30x96); -:onnx_size:43892; -:op_opt_all_added:117; -:op_opt_all_removed:218; -:op_opt_all_time_in:0.41514878199086525; -:op_opt_cst_added:0; -:op_opt_cst_removed:0; -:op_opt_cst_time_in:0.0; -:op_opt_export_optimization:0.41514878199086525; -:op_opt_fused_attention:0; -:op_opt_fused_bias_gelu:0; -:op_opt_fused_cos_sin_cache:0; -:op_opt_fused_erf_gelu:0; -:op_opt_fused_gelu:0; -:op_opt_fused_gqa:0; -:op_opt_fused_mha:0; -:op_opt_fused_packed_qkv_for_gqa:0; -:op_opt_fused_partial_rotary_embedding:0; -:op_opt_fused_rms_normalization:0; -:op_opt_fused_rotary_embedding:0; -:op_opt_fused_sdpa:1; -:op_opt_fused_skip_layer_normalization:0; -:op_opt_fused_skip_rms_normalization:0; -:op_opt_max_iter:13; -:op_opt_n_applied:49; -:op_opt_unique_applied:24; -:op_opt_unique_matched:48; -:time_create:5.664649608999753; -:time_export_export_onnx_c:7.78976812899964; -:time_export_optimization:0.41514878199086525; -:time_onnx_save:0.21937757000159763; -:time_optim_os_ort:0.2757421719998092; -:time_optim_to_ir:0.004986361000192119; -:time_run:0.04424335000112478; -:time_run_patched:0.006461869001213927; -:time_time_onnx_ort_create:0.10978344900104275; -:time_time_onnx_ort_run:0.14714581799853477; -:version_date:2025-04-29T11:56:42; -:version_device:; -:version_do_run:True; -:version_drop_inputs:[]; -:version_dtype:; -:version_dump_folder:test_validate_model_custom_os_ort; -:version_exporter:custom; -:version_model_id:arnir0/Tiny-LLM; -:version_numpy:2.2.5; -:version_onnx:1.19.0; -:version_onnx_diagnostic:0.4.4; -:version_onnxruntime:1.22.0+cu126; -:version_onnxscript:0.3.0.dev20250301; -:version_optimization:default+os_ort; -:version_ortfusiontype:; -:version_patch:True; -:version_quiet:False; -:version_stop_if_static:2; -:version_torch:2.8.0.dev20250423+cu126; -:version_trained:False; -:version_transformers:4.52.0.dev0;