Skip to content

Commit aff1365

Browse files
committed
frsit draft
1 parent 7eb831c commit aff1365

File tree

4 files changed

+275
-0
lines changed

4 files changed

+275
-0
lines changed

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ onnx_diagnostic.torch_export_patches
77

88
patches/index
99
patch_inputs
10+
patch_module
1011

1112

1213
.. automodule:: onnx_diagnostic.torch_export_patches
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_export_patches.patch_module
3+
=================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.patch_module
6+
:members:
7+
:no-undoc-members:
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import ast
2+
import unittest
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
from onnx_diagnostic.torch_export_patches.patch_module import transform_method
6+
7+
8+
class TestPatchModule(ExtTestCase):
9+
def test_rewrite_forward(self):
10+
class Model(torch.nn.Module):
11+
def __init__(self):
12+
super().__init__()
13+
14+
def forward(self, x, y):
15+
if x.sum() > 0:
16+
return x + y
17+
else:
18+
return torch.abs(x) + y
19+
20+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
21+
Model()(x, y)
22+
tree, me = transform_method(Model.forward)
23+
24+
print("-------------")
25+
print(ast.dump(tree.body[0], indent=4))
26+
print("-------------")
27+
code = ast.unparse(tree)
28+
print(code)
29+
print("-------------")
30+
31+
32+
if __name__ == "__main__":
33+
unittest.main(verbosity=2)
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
import ast
2+
import inspect
3+
import types
4+
import textwrap
5+
6+
7+
class RewriteControlFlow(ast.NodeTransformer):
8+
def __init__(self, wrapper_name):
9+
self.wrapper_name = wrapper_name
10+
self.counter = 0
11+
self.current_func_args = None
12+
13+
def visit_FunctionDef(self, node):
14+
# Capture argument names for branch functions
15+
old_args = self.current_func_args
16+
self.current_func_args = [arg.arg for arg in node.args.args]
17+
node.body = [self.visit(n) for n in node.body]
18+
self.current_func_args = old_args
19+
return node
20+
21+
def visit_If(self, node):
22+
# First recurse into subnodes
23+
node = self.generic_visit(node)
24+
test_node = node.test
25+
# Case 1: simple assignment in both branches
26+
if (
27+
len(node.body) == 1
28+
and isinstance(node.body[0], ast.Assign)
29+
and len(node.orelse) == 1
30+
and isinstance(node.orelse[0], ast.Assign)
31+
and self.current_func_args is not None
32+
):
33+
then_assign = node.body[0]
34+
else_assign = node.orelse[0]
35+
tgt = then_assign.targets[0]
36+
if (
37+
isinstance(tgt, ast.Name)
38+
and isinstance(else_assign.targets[0], ast.Name)
39+
and tgt.id == else_assign.targets[0].id
40+
):
41+
self.counter += 1
42+
then_name = f"{self.wrapper_name}_then_{self.counter}"
43+
else_name = f"{self.wrapper_name}_else_{self.counter}"
44+
then_expr = then_assign.value
45+
else_expr = else_assign.value
46+
# extract free variables
47+
then_vars = sorted(
48+
{
49+
n.id
50+
for n in ast.walk(then_expr)
51+
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
52+
}
53+
)
54+
else_vars = sorted(
55+
{
56+
n.id
57+
for n in ast.walk(else_expr)
58+
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
59+
}
60+
)
61+
# build local funcs
62+
then_args = [ast.arg(arg=v, annotation=None) for v in then_vars]
63+
then_def = ast.FunctionDef(
64+
name=then_name,
65+
args=ast.arguments(
66+
posonlyargs=[],
67+
args=then_args,
68+
kwonlyargs=[],
69+
kw_defaults=[],
70+
defaults=[],
71+
),
72+
body=[ast.Return(then_expr)],
73+
decorator_list=[],
74+
returns=None,
75+
)
76+
else_args = [ast.arg(arg=v, annotation=None) for v in else_vars]
77+
else_def = ast.FunctionDef(
78+
name=else_name,
79+
args=ast.arguments(
80+
posonlyargs=[],
81+
args=else_args,
82+
kwonlyargs=[],
83+
kw_defaults=[],
84+
defaults=[],
85+
),
86+
body=[ast.Return(else_expr)],
87+
decorator_list=[],
88+
returns=None,
89+
)
90+
# fix locations
91+
for n in (then_def, else_def):
92+
ast.copy_location(n, node)
93+
ast.fix_missing_locations(n)
94+
# wrapper call and assignment
95+
then_args_tuple = ast.Tuple(
96+
[ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load()
97+
)
98+
else_args_tuple = ast.Tuple(
99+
[ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load()
100+
)
101+
call = ast.Call(
102+
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
103+
args=[
104+
test_node,
105+
ast.Name(id=then_name, ctx=ast.Load()),
106+
ast.Name(id=else_name, ctx=ast.Load()),
107+
then_args_tuple,
108+
else_args_tuple,
109+
],
110+
keywords=[],
111+
)
112+
assign = ast.Assign(targets=[tgt], value=call)
113+
ast.copy_location(assign, node)
114+
ast.fix_missing_locations(assign)
115+
return [then_def, else_def, assign]
116+
# Case 2: simple return in both branches
117+
if (
118+
len(node.body) == 1
119+
and isinstance(node.body[0], ast.Return)
120+
and len(node.orelse) == 1
121+
and isinstance(node.orelse[0], ast.Return)
122+
and self.current_func_args is not None
123+
):
124+
then_ret = node.body[0]
125+
else_ret = node.orelse[0]
126+
then_expr = then_ret.value
127+
else_expr = else_ret.value
128+
self.counter += 1
129+
then_name = f"{self.wrapper_name}_then_{self.counter}"
130+
else_name = f"{self.wrapper_name}_else_{self.counter}"
131+
# extract free variables
132+
then_vars = sorted(
133+
{
134+
n.id
135+
for n in ast.walk(then_expr)
136+
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
137+
}
138+
)
139+
else_vars = sorted(
140+
{
141+
n.id
142+
for n in ast.walk(else_expr)
143+
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
144+
}
145+
)
146+
# build local funcs
147+
then_args = [ast.arg(arg=v, annotation=None) for v in then_vars]
148+
then_def = ast.FunctionDef(
149+
name=then_name,
150+
args=ast.arguments(
151+
posonlyargs=[], args=then_args, kwonlyargs=[], kw_defaults=[], defaults=[]
152+
),
153+
body=[ast.Return(then_expr)],
154+
decorator_list=[],
155+
returns=None,
156+
)
157+
else_args = [ast.arg(arg=v, annotation=None) for v in else_vars]
158+
else_def = ast.FunctionDef(
159+
name=else_name,
160+
args=ast.arguments(
161+
posonlyargs=[], args=else_args, kwonlyargs=[], kw_defaults=[], defaults=[]
162+
),
163+
body=[ast.Return(else_expr)],
164+
decorator_list=[],
165+
returns=None,
166+
)
167+
for n in (then_def, else_def):
168+
ast.copy_location(n, node)
169+
ast.fix_missing_locations(n)
170+
# wrapper call and return
171+
then_args_tuple = ast.Tuple(
172+
[ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load()
173+
)
174+
else_args_tuple = ast.Tuple(
175+
[ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load()
176+
)
177+
call = ast.Call(
178+
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
179+
args=[
180+
test_node,
181+
ast.Name(id=then_name, ctx=ast.Load()),
182+
ast.Name(id=else_name, ctx=ast.Load()),
183+
then_args_tuple,
184+
else_args_tuple,
185+
],
186+
keywords=[],
187+
)
188+
ret = ast.Return(call)
189+
ast.copy_location(ret, node)
190+
ast.fix_missing_locations(ret)
191+
return [then_def, else_def, ret]
192+
return node
193+
194+
def generic_visit(self, node):
195+
return super().generic_visit(node)
196+
197+
198+
def _fix_missing_locations_node(node):
199+
if not hasattr(node, "lineno"):
200+
node.lineno = 999
201+
for chi in ast.iter_child_nodes(node):
202+
_fix_missing_locations_node(chi)
203+
204+
205+
def _fix_missing_locations(new_tree):
206+
for node in ast.walk(new_tree):
207+
_fix_missing_locations_node(node)
208+
209+
210+
def transform_method(func, wrapper_name="torch_cond"):
211+
"""
212+
Returns a new function based on `func` where every test (if, while, assert,
213+
ternary, comparison, boolean op) is replaced by a call to `wrapper_name`.
214+
215+
wrapper_name should refer to a function taking a single boolean argument.
216+
"""
217+
# Retrieve source of the function
218+
src = inspect.getsource(func)
219+
# Parse into AST
220+
tree = ast.parse(textwrap.dedent(src))
221+
# Apply transformation
222+
transformer = RewriteControlFlow(wrapper_name)
223+
new_tree = transformer.visit(tree)
224+
ast.fix_missing_locations(new_tree)
225+
226+
# fix other location
227+
_fix_missing_locations(new_tree)
228+
mod = compile(new_tree, filename="<ast>", mode="exec")
229+
namespace = {}
230+
exec(mod, func.__globals__, namespace)
231+
new_func = namespace.get(func.__name__)
232+
if not isinstance(new_func, types.FunctionType):
233+
raise RuntimeError("Transformed function not found")
234+
return new_tree, new_func

0 commit comments

Comments
 (0)