Skip to content

Commit 86a410f

Browse files
committed
fix a couple of cases
1 parent 2953e47 commit 86a410f

File tree

2 files changed

+195
-76
lines changed

2 files changed

+195
-76
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
44
from onnx_diagnostic.torch_export_patches.patch_module import transform_method
55

66

77
class TestPatchModule(ExtTestCase):
8-
def test_rewrite_forward_return(self):
8+
def test_rewrite_forward_return1(self):
99

1010
class Model(torch.nn.Module):
1111
def __init__(self):
@@ -27,10 +27,39 @@ def forward(self, x, y):
2727
DYN = torch.export.Dim.DYNAMIC
2828
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
2929
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
30+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
3031
got = ep.module()(x, y)
3132
self.assertEqualArray(expected, got)
3233

33-
def test_rewrite_forward_assign(self):
34+
@hide_stdout()
35+
def test_rewrite_forward_return2(self):
36+
37+
class Model(torch.nn.Module):
38+
def __init__(self):
39+
super().__init__()
40+
41+
def forward(self, x, y):
42+
if x.sum() > 0:
43+
return x + y, x - y
44+
else:
45+
return torch.abs(x) + y, torch.abs(x) - y
46+
47+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
48+
expected = Model()(x, y)
49+
50+
rewritten = transform_method(Model.forward, verbose=10)
51+
Model.forward = rewritten.func
52+
Model()(x, y)
53+
54+
DYN = torch.export.Dim.DYNAMIC
55+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
56+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
57+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
58+
got = ep.module()(x, y)
59+
self.assertEqualAny(expected, got)
60+
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
61+
62+
def test_rewrite_forward_assign1(self):
3463

3564
class Model(torch.nn.Module):
3665
def __init__(self):
@@ -46,15 +75,72 @@ def forward(self, x, y):
4675
x, y = torch.rand((3, 4)), torch.rand((3, 4))
4776
expected = Model()(x, y)
4877

49-
rewritten = transform_method(Model.forward)
78+
rewritten = transform_method(Model.forward, verbose=0)
5079
Model.forward = rewritten.func
5180
Model()(x, y)
5281

5382
DYN = torch.export.Dim.DYNAMIC
5483
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
5584
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
85+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
5686
got = ep.module()(x, y)
5787
self.assertEqualArray(expected, got)
88+
self.assertEqualArray(Model()(-x, y), ep.module()(-x, y))
89+
90+
def test_rewrite_forward_assign2(self):
91+
92+
class Model(torch.nn.Module):
93+
def __init__(self):
94+
super().__init__()
95+
96+
def forward(self, x, y):
97+
if x.sum() > 0:
98+
w, z = x + y, x - y
99+
else:
100+
w, z = torch.abs(x) + y, torch.abs(x) - y
101+
return w, z
102+
103+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
104+
expected = Model()(x, y)
105+
106+
rewritten = transform_method(Model.forward, verbose=0)
107+
Model.forward = rewritten.func
108+
Model()(x, y)
109+
110+
DYN = torch.export.Dim.DYNAMIC
111+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
112+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
113+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
114+
got = ep.module()(x, y)
115+
self.assertEqualAny(expected, got)
116+
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
117+
118+
def test_rewrite_forward_noelse(self):
119+
120+
class Model(torch.nn.Module):
121+
def __init__(self):
122+
super().__init__()
123+
124+
def forward(self, x, y):
125+
if x.sum() > 0:
126+
x = torch.abs(x)
127+
return x + y
128+
129+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
130+
expected = Model()(x, y)
131+
132+
rewritten = transform_method(Model.forward, verbose=0)
133+
Model.forward = rewritten.func
134+
Model()(x, y)
135+
136+
DYN = torch.export.Dim.DYNAMIC
137+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
138+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
139+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
140+
got = ep.module()(x, y)
141+
self.assertEqualAny(expected, got)
142+
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
143+
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
58144

59145

60146
if __name__ == "__main__":

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 105 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -46,36 +46,45 @@ def visit_FunctionDef(self, node):
4646
self.current_func_args = old_args
4747
return node
4848

49-
def _rewrite_if(self, node, then_expr, else_expr):
49+
def _find_id(self, exprs):
50+
vars = []
51+
for expr in exprs:
52+
for n in ast.walk(expr):
53+
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load):
54+
vars.append(n.id)
55+
return sorted(set(vars))
56+
57+
def _rewrite_if(self, node, then_exprs, else_exprs, tgt=None):
5058
test_node = node.test
5159

5260
# extract free variables
5361
then_name = f"{self.wrapper_name}_then_{self.counter}"
5462
else_name = f"{self.wrapper_name}_else_{self.counter}"
55-
then_vars = sorted(
56-
{
57-
n.id
58-
for n in ast.walk(then_expr)
59-
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
60-
}
61-
)
62-
else_vars = sorted(
63-
{
64-
n.id
65-
for n in ast.walk(else_expr)
66-
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
67-
}
68-
)
69-
63+
then_vars = self._find_id(then_exprs)
64+
else_vars = self._find_id(else_exprs)
7065
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
66+
then_expr, else_expr = None, None
67+
if tgt is None and len(then_exprs) == 1 and len(else_exprs) == 1:
68+
# return
69+
then_expr = then_exprs[0]
70+
else_expr = else_exprs[0]
71+
elif len(then_exprs) == 1 and len(else_exprs) == 1:
72+
# assignment but only one, so we assume it is the same
73+
then_expr = then_exprs[0]
74+
else_expr = else_exprs[0]
75+
else:
76+
raise NotImplementedError(
77+
f"Unable to rewrite node {node}, len(then_exprs)={len(then_exprs)}, "
78+
f"len(else_exprs)={len(else_exprs)}, "
79+
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
80+
)
7181

7282
# build local funcs
73-
then_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
7483
then_def = ast.FunctionDef(
7584
name=then_name,
7685
args=ast.arguments(
7786
posonlyargs=[],
78-
args=then_args,
87+
args=[ast.arg(arg=v, annotation=None) for v in then_else_vars],
7988
kwonlyargs=[],
8089
kw_defaults=[],
8190
defaults=[],
@@ -84,12 +93,11 @@ def _rewrite_if(self, node, then_expr, else_expr):
8493
decorator_list=[],
8594
returns=None,
8695
)
87-
else_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
8896
else_def = ast.FunctionDef(
8997
name=else_name,
9098
args=ast.arguments(
9199
posonlyargs=[],
92-
args=else_args,
100+
args=[ast.arg(arg=v, annotation=None) for v in then_else_vars],
93101
kwonlyargs=[],
94102
kw_defaults=[],
95103
defaults=[],
@@ -124,50 +132,77 @@ def visit_If(self, node):
124132
# First recurse into subnodes
125133
node = self.generic_visit(node)
126134

127-
# Case 1: simple assignment in both branches
128-
if (
129-
len(node.body) == 1
130-
and isinstance(node.body[0], ast.Assign)
131-
and len(node.orelse) == 1
132-
and isinstance(node.orelse[0], ast.Assign)
133-
and self.current_func_args is not None
134-
):
135-
then_assign = node.body[0]
136-
else_assign = node.orelse[0]
137-
tgt = then_assign.targets[0]
138-
if (
139-
isinstance(tgt, ast.Name)
140-
and isinstance(else_assign.targets[0], ast.Name)
141-
and tgt.id == else_assign.targets[0].id
142-
):
143-
self.counter += 1
144-
then_expr = then_assign.value
145-
else_expr = else_assign.value
146-
then_def, else_def, call = self._rewrite_if(node, then_expr, else_expr)
147-
assign = ast.Assign(targets=[tgt], value=call)
148-
ast.copy_location(assign, node)
149-
ast.fix_missing_locations(assign)
150-
return [then_def, else_def, assign]
135+
has_then_return = any(isinstance(n, ast.Return) for n in node.body)
136+
has_else_return = any(isinstance(n, ast.Return) for n in node.orelse)
137+
ok = (has_then_return and has_else_return) or (
138+
not has_then_return and not has_else_return
139+
)
140+
assert ok, (
141+
f"Cannot mix return and assignment in a test\n--\n"
142+
f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
143+
)
144+
assert self.current_func_args is not None, (
145+
f"current_func_args is None\n--\n"
146+
f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
147+
)
148+
self.counter += 1
151149

152-
# Case 2: simple return in both branches
153-
if (
154-
len(node.body) == 1
155-
and isinstance(node.body[0], ast.Return)
156-
and len(node.orelse) == 1
157-
and isinstance(node.orelse[0], ast.Return)
158-
and self.current_func_args is not None
159-
):
160-
then_ret = node.body[0]
161-
else_ret = node.orelse[0]
162-
then_expr = then_ret.value
163-
else_expr = else_ret.value
164-
self.counter += 1
165-
then_def, else_def, call = self._rewrite_if(node, then_expr, else_expr)
166-
ret = ast.Return(call)
167-
ast.copy_location(ret, node)
168-
ast.fix_missing_locations(ret)
169-
return [then_def, else_def, ret]
170-
return node
150+
if not has_then_return:
151+
# Case 1: simple assignment in both branches
152+
then_assigns = [n for n in node.body if isinstance(n, ast.Assign)]
153+
else_assigns = [n for n in node.orelse if isinstance(n, ast.Assign)]
154+
assert then_assigns or else_assigns, (
155+
f"Missing assignment\n--\n"
156+
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
157+
)
158+
159+
targets = []
160+
for a in [*then_assigns, *else_assigns]:
161+
for t in a.targets:
162+
if isinstance(t, ast.Name):
163+
targets.append((t.id, t))
164+
continue
165+
166+
assert isinstance(t, ast.Tuple) and all(
167+
isinstance(_, ast.Name) for _ in t.elts
168+
), (
169+
f"Unexpected assignment. Not Supported."
170+
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
171+
)
172+
targets.extend((_.id, _) for _ in t.elts)
173+
174+
d = [_[1] for _ in sorted(dict(targets).items())]
175+
tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load())
176+
177+
then_values = [n.value for n in then_assigns]
178+
else_values = [n.value for n in else_assigns]
179+
then_def, else_def, call = self._rewrite_if(
180+
node, then_values, else_values, tgt=tgt
181+
)
182+
183+
assign = ast.Assign(targets=[tgt], value=call)
184+
ast.copy_location(assign, node)
185+
ast.fix_missing_locations(assign)
186+
return [then_def, else_def, assign]
187+
188+
# Case 2: return in both branches, we assume both branches return the same results.
189+
then_ret = node.body[-1]
190+
else_ret = node.orelse[-1]
191+
assert isinstance(then_ret, ast.Return), (
192+
f"return is not the last instruction of then branch"
193+
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
194+
)
195+
assert isinstance(else_ret, ast.Return), (
196+
f"return is not the last instruction of else branch"
197+
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
198+
)
199+
then_expr = then_ret.value
200+
else_expr = else_ret.value
201+
then_def, else_def, call = self._rewrite_if(node, [then_expr], [else_expr])
202+
ret = ast.Return(call)
203+
ast.copy_location(ret, node)
204+
ast.fix_missing_locations(ret)
205+
return [then_def, else_def, ret]
171206

172207
def generic_visit(self, node):
173208
return super().generic_visit(node)
@@ -211,26 +246,24 @@ def transform_method(
211246
# Retrieve source of the function
212247
src = inspect.getsource(func)
213248
if verbose:
214-
print(f"[transform_method] -- source -- {func}")
215-
print(src)
249+
print(f"[transform_method] -- source -- {func}\n\n{src}\n\n[transform_method] --")
216250
# Parse into AST
217251
tree = ast.parse(textwrap.dedent(src))
218252
if verbose > 1:
219-
print("[transform_method] -- tree --")
220-
print(ast.dump(tree, indent=2))
253+
print(f"[transform_method] -- tree --\n\n{ast.dump(tree, indent=2)}")
221254
# Apply transformation
222255
transformer = RewriteControlFlow(if_name)
223256
new_tree = transformer.visit(tree)
224257
if verbose > 1:
225-
print("[transform_method] -- new tree --")
226-
print(ast.dump(tree, indent=2))
258+
print(f"[transform_method] -- new tree --\n\n{ast.dump(tree, indent=2)}")
227259
ast.fix_missing_locations(new_tree)
228260
_settl(new_tree, 0)
229261

230262
if verbose > 0:
231-
print("[transform_method] -- new code --")
232-
code = ast.unparse(new_tree)
233-
print(code)
263+
print(
264+
f"[transform_method] -- new code --\n\n"
265+
f"{ast.unparse(new_tree)}\n\n[transform_method] --"
266+
)
234267
try:
235268
mod = compile(new_tree, filename="<ast>", mode="exec")
236269
except TypeError as e:

0 commit comments

Comments
 (0)