Skip to content

Commit 8e5f454

Browse files
committed
First test working
1 parent aff1365 commit 8e5f454

File tree

2 files changed

+105
-41
lines changed

2 files changed

+105
-41
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import ast
21
import unittest
32
import torch
43
from onnx_diagnostic.ext_test_case import ExtTestCase
@@ -7,6 +6,7 @@
76

87
class TestPatchModule(ExtTestCase):
98
def test_rewrite_forward(self):
9+
1010
class Model(torch.nn.Module):
1111
def __init__(self):
1212
super().__init__()
@@ -18,15 +18,17 @@ def forward(self, x, y):
1818
return torch.abs(x) + y
1919

2020
x, y = torch.rand((3, 4)), torch.rand((3, 4))
21+
expected = Model()(x, y)
22+
23+
rewritten = transform_method(Model.forward)
24+
Model.forward = rewritten.func
2125
Model()(x, y)
22-
tree, me = transform_method(Model.forward)
23-
24-
print("-------------")
25-
print(ast.dump(tree.body[0], indent=4))
26-
print("-------------")
27-
code = ast.unparse(tree)
28-
print(code)
29-
print("-------------")
26+
27+
DYN = torch.export.Dim.DYNAMIC
28+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
29+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
30+
got = ep.module()(x, y)
31+
self.assertEqualArray(expected, got)
3032

3133

3234
if __name__ == "__main__":

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 94 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,34 @@
22
import inspect
33
import types
44
import textwrap
5+
from typing import Callable
6+
import torch
7+
8+
NODE_TYPES = tuple(
9+
getattr(ast, k)
10+
for k in dir(ast)
11+
if "A" <= k[0] <= "Z" and isinstance(getattr(ast, k), type)
12+
)
13+
14+
15+
def _settl(node, lineno, level=0):
16+
if isinstance(node, (str, int, float)):
17+
return node
18+
if isinstance(node, list):
19+
for n in node:
20+
_settl(n, lineno, level=level + 1)
21+
return node
22+
if isinstance(node, NODE_TYPES):
23+
if not hasattr(node, "lineno") or node.lineno is None:
24+
node.lineno = lineno
25+
for k in dir(node):
26+
if k in {"s", "n"}:
27+
continue
28+
if k[0] == "_":
29+
continue
30+
v = getattr(node, k)
31+
_settl(v, max(lineno, node.lineno), level=level + 1)
32+
return node
533

634

735
class RewriteControlFlow(ast.NodeTransformer):
@@ -22,6 +50,7 @@ def visit_If(self, node):
2250
# First recurse into subnodes
2351
node = self.generic_visit(node)
2452
test_node = node.test
53+
2554
# Case 1: simple assignment in both branches
2655
if (
2756
len(node.body) == 1
@@ -91,12 +120,15 @@ def visit_If(self, node):
91120
for n in (then_def, else_def):
92121
ast.copy_location(n, node)
93122
ast.fix_missing_locations(n)
123+
assert hasattr(n, "lineno")
94124
# wrapper call and assignment
95125
then_args_tuple = ast.Tuple(
96-
[ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load()
126+
[ast.Name(id=v, ctx=ast.Load()) for v in then_vars],
127+
ctx=ast.Load(),
97128
)
98129
else_args_tuple = ast.Tuple(
99-
[ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load()
130+
[ast.Name(id=v, ctx=ast.Load()) for v in else_vars],
131+
ctx=ast.Load(),
100132
)
101133
call = ast.Call(
102134
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
@@ -113,6 +145,7 @@ def visit_If(self, node):
113145
ast.copy_location(assign, node)
114146
ast.fix_missing_locations(assign)
115147
return [then_def, else_def, assign]
148+
116149
# Case 2: simple return in both branches
117150
if (
118151
len(node.body) == 1
@@ -143,22 +176,33 @@ def visit_If(self, node):
143176
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
144177
}
145178
)
179+
180+
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
181+
146182
# build local funcs
147-
then_args = [ast.arg(arg=v, annotation=None) for v in then_vars]
183+
then_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
148184
then_def = ast.FunctionDef(
149185
name=then_name,
150186
args=ast.arguments(
151-
posonlyargs=[], args=then_args, kwonlyargs=[], kw_defaults=[], defaults=[]
187+
posonlyargs=[],
188+
args=then_args,
189+
kwonlyargs=[],
190+
kw_defaults=[],
191+
defaults=[],
152192
),
153193
body=[ast.Return(then_expr)],
154194
decorator_list=[],
155195
returns=None,
156196
)
157-
else_args = [ast.arg(arg=v, annotation=None) for v in else_vars]
197+
else_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
158198
else_def = ast.FunctionDef(
159199
name=else_name,
160200
args=ast.arguments(
161-
posonlyargs=[], args=else_args, kwonlyargs=[], kw_defaults=[], defaults=[]
201+
posonlyargs=[],
202+
args=else_args,
203+
kwonlyargs=[],
204+
kw_defaults=[],
205+
defaults=[],
162206
),
163207
body=[ast.Return(else_expr)],
164208
decorator_list=[],
@@ -168,20 +212,18 @@ def visit_If(self, node):
168212
ast.copy_location(n, node)
169213
ast.fix_missing_locations(n)
170214
# wrapper call and return
171-
then_args_tuple = ast.Tuple(
172-
[ast.Name(id=v, ctx=ast.Load()) for v in then_vars], ctx=ast.Load()
173-
)
174-
else_args_tuple = ast.Tuple(
175-
[ast.Name(id=v, ctx=ast.Load()) for v in else_vars], ctx=ast.Load()
215+
then_else_args_list = ast.List(
216+
[ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
217+
ctx=ast.Load(),
176218
)
219+
177220
call = ast.Call(
178221
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
179222
args=[
180223
test_node,
181224
ast.Name(id=then_name, ctx=ast.Load()),
182225
ast.Name(id=else_name, ctx=ast.Load()),
183-
then_args_tuple,
184-
else_args_tuple,
226+
then_else_args_list,
185227
],
186228
keywords=[],
187229
)
@@ -195,24 +237,29 @@ def generic_visit(self, node):
195237
return super().generic_visit(node)
196238

197239

198-
def _fix_missing_locations_node(node):
199-
if not hasattr(node, "lineno"):
200-
node.lineno = 999
201-
for chi in ast.iter_child_nodes(node):
202-
_fix_missing_locations_node(chi)
240+
class RewrittenMethod:
241+
"""
242+
Stores a rewritten method using
243+
:func:`onnx_diagnostic.torch_export_patches.path_module.transform_method>`.
244+
"""
245+
246+
def __init__(self, tree, func):
247+
self.tree = tree
248+
self.func = func
203249

250+
@property
251+
def code(self) -> str:
252+
"""Returns the source."""
253+
return ast.unparse(self.tree)
204254

205-
def _fix_missing_locations(new_tree):
206-
for node in ast.walk(new_tree):
207-
_fix_missing_locations_node(node)
255+
def __repr__(self):
256+
return f"{self.__class__.__name__}({self.func})"
208257

209258

210-
def transform_method(func, wrapper_name="torch_cond"):
259+
def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMethod:
211260
"""
212-
Returns a new function based on `func` where every test (if, while, assert,
213-
ternary, comparison, boolean op) is replaced by a call to `wrapper_name`.
214-
215-
wrapper_name should refer to a function taking a single boolean argument.
261+
Returns a new function based on `func` where every test (if)
262+
is replaced by a call to :func:`torch.cond`.
216263
"""
217264
# Retrieve source of the function
218265
src = inspect.getsource(func)
@@ -222,13 +269,28 @@ def transform_method(func, wrapper_name="torch_cond"):
222269
transformer = RewriteControlFlow(wrapper_name)
223270
new_tree = transformer.visit(tree)
224271
ast.fix_missing_locations(new_tree)
225-
226-
# fix other location
227-
_fix_missing_locations(new_tree)
228-
mod = compile(new_tree, filename="<ast>", mode="exec")
272+
_settl(new_tree, 0)
273+
try:
274+
mod = compile(new_tree, filename="<ast>", mode="exec")
275+
except TypeError as e:
276+
if 'required field "lineno" missing from stmt' in str(e):
277+
# Could not find a way to avoid compilng a string.
278+
# The error message still pops up without indicating which node is not
279+
# properly set.
280+
code = ast.unparse(new_tree)
281+
mod = compile(code, filename="<source>", mode="exec")
282+
else:
283+
kws = dict(include_attributes=True, annotate_fields=True, indent=4)
284+
raise RuntimeError(
285+
f"Unable to compile code\n--CODE--\n"
286+
f"{ast.unparse(new_tree)}\n--TREE--\n"
287+
f"{ast.dump(new_tree, **kws)}"
288+
) from e
229289
namespace = {}
230-
exec(mod, func.__globals__, namespace)
290+
globs = func.__globals__.copy()
291+
globs["torch_cond"] = torch.cond
292+
exec(mod, globs, namespace)
231293
new_func = namespace.get(func.__name__)
232294
if not isinstance(new_func, types.FunctionType):
233295
raise RuntimeError("Transformed function not found")
234-
return new_tree, new_func
296+
return RewrittenMethod(new_tree, new_func)

0 commit comments

Comments
 (0)