From aff13658a72e8f85f704a5db85986f731a3960a9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sun, 27 Apr 2025 15:20:24 +0200 Subject: [PATCH 1/4] frsit draft --- _doc/api/torch_export_patches/index.rst | 1 + .../api/torch_export_patches/patch_module.rst | 7 + .../test_patch_module.py | 33 +++ .../torch_export_patches/patch_module.py | 234 ++++++++++++++++++ 4 files changed, 275 insertions(+) create mode 100644 _doc/api/torch_export_patches/patch_module.rst create mode 100644 _unittests/ut_torch_export_patches/test_patch_module.py create mode 100644 onnx_diagnostic/torch_export_patches/patch_module.py diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index b35e8c62..d47ec2f1 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -7,6 +7,7 @@ onnx_diagnostic.torch_export_patches patches/index patch_inputs + patch_module .. automodule:: onnx_diagnostic.torch_export_patches diff --git a/_doc/api/torch_export_patches/patch_module.rst b/_doc/api/torch_export_patches/patch_module.rst new file mode 100644 index 00000000..f3759993 --- /dev/null +++ b/_doc/api/torch_export_patches/patch_module.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_export_patches.patch_module +================================================= + +.. automodule:: onnx_diagnostic.torch_export_patches.patch_module + :members: + :no-undoc-members: diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py new file mode 100644 index 00000000..a8330c06 --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -0,0 +1,33 @@ +import ast +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.torch_export_patches.patch_module import transform_method + + +class TestPatchModule(ExtTestCase): + def test_rewrite_forward(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + if x.sum() > 0: + return x + y + else: + return torch.abs(x) + y + + x, y = torch.rand((3, 4)), torch.rand((3, 4)) + Model()(x, y) + tree, me = transform_method(Model.forward) + + print("-------------") + print(ast.dump(tree.body[0], indent=4)) + print("-------------") + code = ast.unparse(tree) + print(code) + print("-------------") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py new file mode 100644 index 00000000..7f47ef41 --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -0,0 +1,234 @@ +import ast +import inspect +import types +import textwrap + + +class RewriteControlFlow(ast.NodeTransformer): + def __init__(self, wrapper_name): + self.wrapper_name = wrapper_name + self.counter = 0 + self.current_func_args = None + + def visit_FunctionDef(self, node): + # Capture argument names for branch functions + old_args = self.current_func_args + self.current_func_args = [arg.arg for arg in node.args.args] + node.body = [self.visit(n) for n in node.body] + self.current_func_args = old_args + return node + + def visit_If(self, node): + # First recurse into subnodes + node = self.generic_visit(node) + test_node = node.test + # Case 1: simple assignment in both branches + if ( + len(node.body) == 1 + and isinstance(node.body[0], ast.Assign) + and len(node.orelse) == 1 + and isinstance(node.orelse[0], ast.Assign) + and self.current_func_args is not None + ): + then_assign = node.body[0] + else_assign = node.orelse[0] + tgt = then_assign.targets[0] + if ( + isinstance(tgt, ast.Name) + and isinstance(else_assign.targets[0], ast.Name) + and tgt.id == else_assign.targets[0].id + ): + self.counter += 1 + then_name = f"{self.wrapper_name}_then_{self.counter}" + else_name = f"{self.wrapper_name}_else_{self.counter}" + then_expr = then_assign.value + else_expr = else_assign.value + # extract free variables + then_vars = sorted( + { + n.id + for n in ast.walk(then_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + else_vars = sorted( + { + n.id + for n in ast.walk(else_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + # build local funcs + then_args = [ast.arg(arg=v, annotation=None) for v in then_vars] + then_def = ast.FunctionDef( + name=then_name, + args=ast.arguments( + posonlyargs=[], + args=then_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[ast.Return(then_expr)], + decorator_list=[], + returns=None, + ) + else_args = [ast.arg(arg=v, annotation=None) for v in else_vars] + else_def = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], + args=else_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[ast.Return(else_expr)], + decorator_list=[], + returns=None, + ) + # fix locations + for n in (then_def, else_def): + ast.copy_location(n, node) + ast.fix_missing_locations(n) + # wrapper call and assignment + then_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() + ) + else_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() + ) + call = ast.Call( + func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), + args=[ + test_node, + ast.Name(id=then_name, ctx=ast.Load()), + ast.Name(id=else_name, ctx=ast.Load()), + then_args_tuple, + else_args_tuple, + ], + keywords=[], + ) + assign = ast.Assign(targets=[tgt], value=call) + ast.copy_location(assign, node) + ast.fix_missing_locations(assign) + return [then_def, else_def, assign] + # Case 2: simple return in both branches + if ( + len(node.body) == 1 + and isinstance(node.body[0], ast.Return) + and len(node.orelse) == 1 + and isinstance(node.orelse[0], ast.Return) + and self.current_func_args is not None + ): + then_ret = node.body[0] + else_ret = node.orelse[0] + then_expr = then_ret.value + else_expr = else_ret.value + self.counter += 1 + then_name = f"{self.wrapper_name}_then_{self.counter}" + else_name = f"{self.wrapper_name}_else_{self.counter}" + # extract free variables + then_vars = sorted( + { + n.id + for n in ast.walk(then_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + else_vars = sorted( + { + n.id + for n in ast.walk(else_expr) + if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) + } + ) + # build local funcs + then_args = [ast.arg(arg=v, annotation=None) for v in then_vars] + then_def = ast.FunctionDef( + name=then_name, + args=ast.arguments( + posonlyargs=[], args=then_args, kwonlyargs=[], kw_defaults=[], defaults=[] + ), + body=[ast.Return(then_expr)], + decorator_list=[], + returns=None, + ) + else_args = [ast.arg(arg=v, annotation=None) for v in else_vars] + else_def = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], args=else_args, kwonlyargs=[], kw_defaults=[], defaults=[] + ), + body=[ast.Return(else_expr)], + decorator_list=[], + returns=None, + ) + for n in (then_def, else_def): + ast.copy_location(n, node) + ast.fix_missing_locations(n) + # wrapper call and return + then_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() + ) + else_args_tuple = ast.Tuple( + [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() + ) + call = ast.Call( + func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), + args=[ + test_node, + ast.Name(id=then_name, ctx=ast.Load()), + ast.Name(id=else_name, ctx=ast.Load()), + then_args_tuple, + else_args_tuple, + ], + keywords=[], + ) + ret = ast.Return(call) + ast.copy_location(ret, node) + ast.fix_missing_locations(ret) + return [then_def, else_def, ret] + return node + + def generic_visit(self, node): + return super().generic_visit(node) + + +def _fix_missing_locations_node(node): + if not hasattr(node, "lineno"): + node.lineno = 999 + for chi in ast.iter_child_nodes(node): + _fix_missing_locations_node(chi) + + +def _fix_missing_locations(new_tree): + for node in ast.walk(new_tree): + _fix_missing_locations_node(node) + + +def transform_method(func, wrapper_name="torch_cond"): + """ + Returns a new function based on `func` where every test (if, while, assert, + ternary, comparison, boolean op) is replaced by a call to `wrapper_name`. + + wrapper_name should refer to a function taking a single boolean argument. + """ + # Retrieve source of the function + src = inspect.getsource(func) + # Parse into AST + tree = ast.parse(textwrap.dedent(src)) + # Apply transformation + transformer = RewriteControlFlow(wrapper_name) + new_tree = transformer.visit(tree) + ast.fix_missing_locations(new_tree) + + # fix other location + _fix_missing_locations(new_tree) + mod = compile(new_tree, filename="", mode="exec") + namespace = {} + exec(mod, func.__globals__, namespace) + new_func = namespace.get(func.__name__) + if not isinstance(new_func, types.FunctionType): + raise RuntimeError("Transformed function not found") + return new_tree, new_func From 8e5f4545893499621fa2f2b4421bfaf595698825 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 29 Apr 2025 17:07:00 +0200 Subject: [PATCH 2/4] First test working --- .../test_patch_module.py | 20 +-- .../torch_export_patches/patch_module.py | 126 +++++++++++++----- 2 files changed, 105 insertions(+), 41 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py index a8330c06..941465e6 100644 --- a/_unittests/ut_torch_export_patches/test_patch_module.py +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -1,4 +1,3 @@ -import ast import unittest import torch from onnx_diagnostic.ext_test_case import ExtTestCase @@ -7,6 +6,7 @@ class TestPatchModule(ExtTestCase): def test_rewrite_forward(self): + class Model(torch.nn.Module): def __init__(self): super().__init__() @@ -18,15 +18,17 @@ def forward(self, x, y): return torch.abs(x) + y x, y = torch.rand((3, 4)), torch.rand((3, 4)) + expected = Model()(x, y) + + rewritten = transform_method(Model.forward) + Model.forward = rewritten.func Model()(x, y) - tree, me = transform_method(Model.forward) - - print("-------------") - print(ast.dump(tree.body[0], indent=4)) - print("-------------") - code = ast.unparse(tree) - print(code) - print("-------------") + + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds) + got = ep.module()(x, y) + self.assertEqualArray(expected, got) if __name__ == "__main__": diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index 7f47ef41..20381cc8 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -2,6 +2,34 @@ import inspect import types import textwrap +from typing import Callable +import torch + +NODE_TYPES = tuple( + getattr(ast, k) + for k in dir(ast) + if "A" <= k[0] <= "Z" and isinstance(getattr(ast, k), type) +) + + +def _settl(node, lineno, level=0): + if isinstance(node, (str, int, float)): + return node + if isinstance(node, list): + for n in node: + _settl(n, lineno, level=level + 1) + return node + if isinstance(node, NODE_TYPES): + if not hasattr(node, "lineno") or node.lineno is None: + node.lineno = lineno + for k in dir(node): + if k in {"s", "n"}: + continue + if k[0] == "_": + continue + v = getattr(node, k) + _settl(v, max(lineno, node.lineno), level=level + 1) + return node class RewriteControlFlow(ast.NodeTransformer): @@ -22,6 +50,7 @@ def visit_If(self, node): # First recurse into subnodes node = self.generic_visit(node) test_node = node.test + # Case 1: simple assignment in both branches if ( len(node.body) == 1 @@ -91,12 +120,15 @@ def visit_If(self, node): for n in (then_def, else_def): ast.copy_location(n, node) ast.fix_missing_locations(n) + assert hasattr(n, "lineno") # wrapper call and assignment then_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() + [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], + ctx=ast.Load(), ) else_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() + [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], + ctx=ast.Load(), ) call = ast.Call( func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), @@ -113,6 +145,7 @@ def visit_If(self, node): ast.copy_location(assign, node) ast.fix_missing_locations(assign) return [then_def, else_def, assign] + # Case 2: simple return in both branches if ( len(node.body) == 1 @@ -143,22 +176,33 @@ def visit_If(self, node): if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) } ) + + then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch") + # build local funcs - then_args = [ast.arg(arg=v, annotation=None) for v in then_vars] + then_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars] then_def = ast.FunctionDef( name=then_name, args=ast.arguments( - posonlyargs=[], args=then_args, kwonlyargs=[], kw_defaults=[], defaults=[] + posonlyargs=[], + args=then_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], ), body=[ast.Return(then_expr)], decorator_list=[], returns=None, ) - else_args = [ast.arg(arg=v, annotation=None) for v in else_vars] + else_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars] else_def = ast.FunctionDef( name=else_name, args=ast.arguments( - posonlyargs=[], args=else_args, kwonlyargs=[], kw_defaults=[], defaults=[] + posonlyargs=[], + args=else_args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], ), body=[ast.Return(else_expr)], decorator_list=[], @@ -168,20 +212,18 @@ def visit_If(self, node): ast.copy_location(n, node) ast.fix_missing_locations(n) # wrapper call and return - then_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load() - ) - else_args_tuple = ast.Tuple( - [ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load() + then_else_args_list = ast.List( + [ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars], + ctx=ast.Load(), ) + call = ast.Call( func=ast.Name(id=self.wrapper_name, ctx=ast.Load()), args=[ test_node, ast.Name(id=then_name, ctx=ast.Load()), ast.Name(id=else_name, ctx=ast.Load()), - then_args_tuple, - else_args_tuple, + then_else_args_list, ], keywords=[], ) @@ -195,24 +237,29 @@ def generic_visit(self, node): return super().generic_visit(node) -def _fix_missing_locations_node(node): - if not hasattr(node, "lineno"): - node.lineno = 999 - for chi in ast.iter_child_nodes(node): - _fix_missing_locations_node(chi) +class RewrittenMethod: + """ + Stores a rewritten method using + :func:`onnx_diagnostic.torch_export_patches.path_module.transform_method>`. + """ + + def __init__(self, tree, func): + self.tree = tree + self.func = func + @property + def code(self) -> str: + """Returns the source.""" + return ast.unparse(self.tree) -def _fix_missing_locations(new_tree): - for node in ast.walk(new_tree): - _fix_missing_locations_node(node) + def __repr__(self): + return f"{self.__class__.__name__}({self.func})" -def transform_method(func, wrapper_name="torch_cond"): +def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMethod: """ - Returns a new function based on `func` where every test (if, while, assert, - ternary, comparison, boolean op) is replaced by a call to `wrapper_name`. - - wrapper_name should refer to a function taking a single boolean argument. + Returns a new function based on `func` where every test (if) + is replaced by a call to :func:`torch.cond`. """ # Retrieve source of the function src = inspect.getsource(func) @@ -222,13 +269,28 @@ def transform_method(func, wrapper_name="torch_cond"): transformer = RewriteControlFlow(wrapper_name) new_tree = transformer.visit(tree) ast.fix_missing_locations(new_tree) - - # fix other location - _fix_missing_locations(new_tree) - mod = compile(new_tree, filename="", mode="exec") + _settl(new_tree, 0) + try: + mod = compile(new_tree, filename="", mode="exec") + except TypeError as e: + if 'required field "lineno" missing from stmt' in str(e): + # Could not find a way to avoid compilng a string. + # The error message still pops up without indicating which node is not + # properly set. + code = ast.unparse(new_tree) + mod = compile(code, filename="", mode="exec") + else: + kws = dict(include_attributes=True, annotate_fields=True, indent=4) + raise RuntimeError( + f"Unable to compile code\n--CODE--\n" + f"{ast.unparse(new_tree)}\n--TREE--\n" + f"{ast.dump(new_tree, **kws)}" + ) from e namespace = {} - exec(mod, func.__globals__, namespace) + globs = func.__globals__.copy() + globs["torch_cond"] = torch.cond + exec(mod, globs, namespace) new_func = namespace.get(func.__name__) if not isinstance(new_func, types.FunctionType): raise RuntimeError("Transformed function not found") - return new_tree, new_func + return RewrittenMethod(new_tree, new_func) From 027772d77b0b7d0cafafe99fbf4154925af83c30 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 29 Apr 2025 17:59:02 +0200 Subject: [PATCH 3/4] mypy --- onnx_diagnostic/torch_export_patches/patch_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index 20381cc8..87182cde 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -2,7 +2,7 @@ import inspect import types import textwrap -from typing import Callable +from typing import Callable, Dict import torch NODE_TYPES = tuple( @@ -240,7 +240,7 @@ def generic_visit(self, node): class RewrittenMethod: """ Stores a rewritten method using - :func:`onnx_diagnostic.torch_export_patches.path_module.transform_method>`. + :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method>`. """ def __init__(self, tree, func): @@ -286,7 +286,7 @@ def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMeth f"{ast.unparse(new_tree)}\n--TREE--\n" f"{ast.dump(new_tree, **kws)}" ) from e - namespace = {} + namespace: Dict[type, type] = {} globs = func.__globals__.copy() globs["torch_cond"] = torch.cond exec(mod, globs, namespace) From 420977db4d2830039102c278d511fdf316fb3e2f Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 29 Apr 2025 18:44:31 +0200 Subject: [PATCH 4/4] add --- .../torch_export_patches/patch_module.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index 87182cde..67b1e503 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -240,7 +240,10 @@ def generic_visit(self, node): class RewrittenMethod: """ Stores a rewritten method using - :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method>`. + :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`. + + :param tree: ast tree + :param func: callable compiled from the tree """ def __init__(self, tree, func): @@ -253,20 +256,25 @@ def code(self) -> str: return ast.unparse(self.tree) def __repr__(self): + "usual" return f"{self.__class__.__name__}({self.func})" -def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMethod: +def transform_method(func: Callable, if_name="torch_cond") -> RewrittenMethod: """ Returns a new function based on `func` where every test (if) is replaced by a call to :func:`torch.cond`. + + :param func: method or function to rewrite + :param if_name: function calling the test + :return: rewritten method """ # Retrieve source of the function src = inspect.getsource(func) # Parse into AST tree = ast.parse(textwrap.dedent(src)) # Apply transformation - transformer = RewriteControlFlow(wrapper_name) + transformer = RewriteControlFlow(if_name) new_tree = transformer.visit(tree) ast.fix_missing_locations(new_tree) _settl(new_tree, 0) @@ -286,9 +294,9 @@ def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMeth f"{ast.unparse(new_tree)}\n--TREE--\n" f"{ast.dump(new_tree, **kws)}" ) from e - namespace: Dict[type, type] = {} + namespace: Dict[str, type] = {} globs = func.__globals__.copy() - globs["torch_cond"] = torch.cond + globs[if_name] = torch.cond exec(mod, globs, namespace) new_func = namespace.get(func.__name__) if not isinstance(new_func, types.FunctionType):