Skip to content

Commit 2953e47

Browse files
committed
refactor the code
1 parent 94d9508 commit 2953e47

File tree

2 files changed

+121
-142
lines changed

2 files changed

+121
-142
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class TestPatchModule(ExtTestCase):
8-
def test_rewrite_forward(self):
8+
def test_rewrite_forward_return(self):
99

1010
class Model(torch.nn.Module):
1111
def __init__(self):
@@ -30,6 +30,32 @@ def forward(self, x, y):
3030
got = ep.module()(x, y)
3131
self.assertEqualArray(expected, got)
3232

33+
def test_rewrite_forward_assign(self):
34+
35+
class Model(torch.nn.Module):
36+
def __init__(self):
37+
super().__init__()
38+
39+
def forward(self, x, y):
40+
if x.sum() > 0:
41+
z = x + y
42+
else:
43+
z = torch.abs(x) + y
44+
return z
45+
46+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
47+
expected = Model()(x, y)
48+
49+
rewritten = transform_method(Model.forward)
50+
Model.forward = rewritten.func
51+
Model()(x, y)
52+
53+
DYN = torch.export.Dim.DYNAMIC
54+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
55+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
56+
got = ep.module()(x, y)
57+
self.assertEqualArray(expected, got)
58+
3359

3460
if __name__ == "__main__":
3561
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 94 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,83 @@ def visit_FunctionDef(self, node):
4646
self.current_func_args = old_args
4747
return node
4848

49+
def _rewrite_if(self, node, then_expr, else_expr):
50+
test_node = node.test
51+
52+
# extract free variables
53+
then_name = f"{self.wrapper_name}_then_{self.counter}"
54+
else_name = f"{self.wrapper_name}_else_{self.counter}"
55+
then_vars = sorted(
56+
{
57+
n.id
58+
for n in ast.walk(then_expr)
59+
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
60+
}
61+
)
62+
else_vars = sorted(
63+
{
64+
n.id
65+
for n in ast.walk(else_expr)
66+
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
67+
}
68+
)
69+
70+
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
71+
72+
# build local funcs
73+
then_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
74+
then_def = ast.FunctionDef(
75+
name=then_name,
76+
args=ast.arguments(
77+
posonlyargs=[],
78+
args=then_args,
79+
kwonlyargs=[],
80+
kw_defaults=[],
81+
defaults=[],
82+
),
83+
body=[ast.Return(then_expr)],
84+
decorator_list=[],
85+
returns=None,
86+
)
87+
else_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
88+
else_def = ast.FunctionDef(
89+
name=else_name,
90+
args=ast.arguments(
91+
posonlyargs=[],
92+
args=else_args,
93+
kwonlyargs=[],
94+
kw_defaults=[],
95+
defaults=[],
96+
),
97+
body=[ast.Return(else_expr)],
98+
decorator_list=[],
99+
returns=None,
100+
)
101+
# fix locations
102+
for n in (then_def, else_def):
103+
ast.copy_location(n, node)
104+
ast.fix_missing_locations(n)
105+
assert hasattr(n, "lineno")
106+
# wrapper call and assignment
107+
then_else_args_list = ast.List(
108+
[ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
109+
ctx=ast.Load(),
110+
)
111+
call = ast.Call(
112+
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
113+
args=[
114+
test_node,
115+
ast.Name(id=then_name, ctx=ast.Load()),
116+
ast.Name(id=else_name, ctx=ast.Load()),
117+
then_else_args_list,
118+
],
119+
keywords=[],
120+
)
121+
return then_def, else_def, call
122+
49123
def visit_If(self, node):
50124
# First recurse into subnodes
51125
node = self.generic_visit(node)
52-
test_node = node.test
53126

54127
# Case 1: simple assignment in both branches
55128
if (
@@ -68,79 +141,9 @@ def visit_If(self, node):
68141
and tgt.id == else_assign.targets[0].id
69142
):
70143
self.counter += 1
71-
then_name = f"{self.wrapper_name}_then_{self.counter}"
72-
else_name = f"{self.wrapper_name}_else_{self.counter}"
73144
then_expr = then_assign.value
74145
else_expr = else_assign.value
75-
# extract free variables
76-
then_vars = sorted(
77-
{
78-
n.id
79-
for n in ast.walk(then_expr)
80-
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
81-
}
82-
)
83-
else_vars = sorted(
84-
{
85-
n.id
86-
for n in ast.walk(else_expr)
87-
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
88-
}
89-
)
90-
# build local funcs
91-
then_args = [ast.arg(arg=v, annotation=None) for v in then_vars]
92-
then_def = ast.FunctionDef(
93-
name=then_name,
94-
args=ast.arguments(
95-
posonlyargs=[],
96-
args=then_args,
97-
kwonlyargs=[],
98-
kw_defaults=[],
99-
defaults=[],
100-
),
101-
body=[ast.Return(then_expr)],
102-
decorator_list=[],
103-
returns=None,
104-
)
105-
else_args = [ast.arg(arg=v, annotation=None) for v in else_vars]
106-
else_def = ast.FunctionDef(
107-
name=else_name,
108-
args=ast.arguments(
109-
posonlyargs=[],
110-
args=else_args,
111-
kwonlyargs=[],
112-
kw_defaults=[],
113-
defaults=[],
114-
),
115-
body=[ast.Return(else_expr)],
116-
decorator_list=[],
117-
returns=None,
118-
)
119-
# fix locations
120-
for n in (then_def, else_def):
121-
ast.copy_location(n, node)
122-
ast.fix_missing_locations(n)
123-
assert hasattr(n, "lineno")
124-
# wrapper call and assignment
125-
then_args_tuple = ast.Tuple(
126-
[ast.Name(id=v, ctx=ast.Load()) for v in then_vars],
127-
ctx=ast.Load(),
128-
)
129-
else_args_tuple = ast.Tuple(
130-
[ast.Name(id=v, ctx=ast.Load()) for v in else_vars],
131-
ctx=ast.Load(),
132-
)
133-
call = ast.Call(
134-
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
135-
args=[
136-
test_node,
137-
ast.Name(id=then_name, ctx=ast.Load()),
138-
ast.Name(id=else_name, ctx=ast.Load()),
139-
then_args_tuple,
140-
else_args_tuple,
141-
],
142-
keywords=[],
143-
)
146+
then_def, else_def, call = self._rewrite_if(node, then_expr, else_expr)
144147
assign = ast.Assign(targets=[tgt], value=call)
145148
ast.copy_location(assign, node)
146149
ast.fix_missing_locations(assign)
@@ -159,74 +162,7 @@ def visit_If(self, node):
159162
then_expr = then_ret.value
160163
else_expr = else_ret.value
161164
self.counter += 1
162-
then_name = f"{self.wrapper_name}_then_{self.counter}"
163-
else_name = f"{self.wrapper_name}_else_{self.counter}"
164-
# extract free variables
165-
then_vars = sorted(
166-
{
167-
n.id
168-
for n in ast.walk(then_expr)
169-
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
170-
}
171-
)
172-
else_vars = sorted(
173-
{
174-
n.id
175-
for n in ast.walk(else_expr)
176-
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
177-
}
178-
)
179-
180-
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
181-
182-
# build local funcs
183-
then_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
184-
then_def = ast.FunctionDef(
185-
name=then_name,
186-
args=ast.arguments(
187-
posonlyargs=[],
188-
args=then_args,
189-
kwonlyargs=[],
190-
kw_defaults=[],
191-
defaults=[],
192-
),
193-
body=[ast.Return(then_expr)],
194-
decorator_list=[],
195-
returns=None,
196-
)
197-
else_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
198-
else_def = ast.FunctionDef(
199-
name=else_name,
200-
args=ast.arguments(
201-
posonlyargs=[],
202-
args=else_args,
203-
kwonlyargs=[],
204-
kw_defaults=[],
205-
defaults=[],
206-
),
207-
body=[ast.Return(else_expr)],
208-
decorator_list=[],
209-
returns=None,
210-
)
211-
for n in (then_def, else_def):
212-
ast.copy_location(n, node)
213-
ast.fix_missing_locations(n)
214-
# wrapper call and return
215-
then_else_args_list = ast.List(
216-
[ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
217-
ctx=ast.Load(),
218-
)
219-
220-
call = ast.Call(
221-
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
222-
args=[
223-
test_node,
224-
ast.Name(id=then_name, ctx=ast.Load()),
225-
ast.Name(id=else_name, ctx=ast.Load()),
226-
then_else_args_list,
227-
],
228-
keywords=[],
229-
)
165+
then_def, else_def, call = self._rewrite_if(node, then_expr, else_expr)
230166
ret = ast.Return(call)
231167
ast.copy_location(ret, node)
232168
ast.fix_missing_locations(ret)
@@ -260,24 +196,41 @@ def __repr__(self):
260196
return f"{self.__class__.__name__}({self.func})"
261197

262198

263-
def transform_method(func: Callable, if_name="torch_cond") -> RewrittenMethod:
199+
def transform_method(
200+
func: Callable, if_name="torch_cond", verbose: int = 0
201+
) -> RewrittenMethod:
264202
"""
265203
Returns a new function based on `func` where every test (if)
266204
is replaced by a call to :func:`torch.cond`.
267205
268206
:param func: method or function to rewrite
269207
:param if_name: function calling the test
208+
:param verbose: verbosity
270209
:return: rewritten method
271210
"""
272211
# Retrieve source of the function
273212
src = inspect.getsource(func)
213+
if verbose:
214+
print(f"[transform_method] -- source -- {func}")
215+
print(src)
274216
# Parse into AST
275217
tree = ast.parse(textwrap.dedent(src))
218+
if verbose > 1:
219+
print("[transform_method] -- tree --")
220+
print(ast.dump(tree, indent=2))
276221
# Apply transformation
277222
transformer = RewriteControlFlow(if_name)
278223
new_tree = transformer.visit(tree)
224+
if verbose > 1:
225+
print("[transform_method] -- new tree --")
226+
print(ast.dump(tree, indent=2))
279227
ast.fix_missing_locations(new_tree)
280228
_settl(new_tree, 0)
229+
230+
if verbose > 0:
231+
print("[transform_method] -- new code --")
232+
code = ast.unparse(new_tree)
233+
print(code)
281234
try:
282235
mod = compile(new_tree, filename="<ast>", mode="exec")
283236
except TypeError as e:

0 commit comments

Comments
 (0)