Skip to content

Commit 105cbe6

Browse files
committed
issues
1 parent 86a410f commit 105cbe6

File tree

2 files changed

+153
-71
lines changed

2 files changed

+153
-71
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,139 +8,134 @@ class TestPatchModule(ExtTestCase):
88
def test_rewrite_forward_return1(self):
99

1010
class Model(torch.nn.Module):
11-
def __init__(self):
12-
super().__init__()
13-
1411
def forward(self, x, y):
1512
if x.sum() > 0:
1613
return x + y
1714
else:
18-
return torch.abs(x) + y
15+
return torch.abs(x) + y + 1
1916

2017
x, y = torch.rand((3, 4)), torch.rand((3, 4))
21-
expected = Model()(x, y)
18+
expected, expected_ = Model()(x, y), Model()(-x, y)
2219

2320
rewritten = transform_method(Model.forward)
21+
self.assertIn("torch.abs(", rewritten.code)
22+
self.assertIn("'abs'", rewritten.dump)
2423
Model.forward = rewritten.func
25-
Model()(x, y)
24+
self.assertEqualAny(expected, Model()(x, y))
25+
self.assertEqualAny(expected_, Model()(-x, y))
2626

2727
DYN = torch.export.Dim.DYNAMIC
2828
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
2929
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
3030
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
31-
got = ep.module()(x, y)
32-
self.assertEqualArray(expected, got)
31+
self.assertEqualAny(expected, ep.module()(x, y))
32+
self.assertEqualAny(expected_, ep.module()(-x, y))
3333

3434
@hide_stdout()
3535
def test_rewrite_forward_return2(self):
3636

3737
class Model(torch.nn.Module):
38-
def __init__(self):
39-
super().__init__()
40-
4138
def forward(self, x, y):
4239
if x.sum() > 0:
4340
return x + y, x - y
4441
else:
45-
return torch.abs(x) + y, torch.abs(x) - y
42+
return torch.abs(x) + y + 1, torch.abs(x) - y + 1
4643

4744
x, y = torch.rand((3, 4)), torch.rand((3, 4))
48-
expected = Model()(x, y)
45+
expected, expected_ = Model()(x, y), Model()(-x, y)
4946

5047
rewritten = transform_method(Model.forward, verbose=10)
48+
self.assertIn("torch.abs(", rewritten.code)
49+
self.assertIn("abs", rewritten.dump)
5150
Model.forward = rewritten.func
52-
Model()(x, y)
51+
self.assertEqualAny(expected, Model()(x, y))
52+
self.assertEqualAny(expected_, Model()(-x, y))
5353

5454
DYN = torch.export.Dim.DYNAMIC
5555
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
5656
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
5757
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
58-
got = ep.module()(x, y)
59-
self.assertEqualAny(expected, got)
60-
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
58+
self.assertEqualAny(expected, ep.module()(x, y))
59+
self.assertEqualAny(expected_, ep.module()(-x, y))
6160

6261
def test_rewrite_forward_assign1(self):
6362

6463
class Model(torch.nn.Module):
65-
def __init__(self):
66-
super().__init__()
67-
6864
def forward(self, x, y):
6965
if x.sum() > 0:
7066
z = x + y
7167
else:
72-
z = torch.abs(x) + y
68+
z = torch.abs(x) + y + 1
7369
return z
7470

7571
x, y = torch.rand((3, 4)), torch.rand((3, 4))
76-
expected = Model()(x, y)
72+
expected, expected_ = Model()(x, y), Model()(-x, y)
7773

7874
rewritten = transform_method(Model.forward, verbose=0)
75+
self.assertIn("torch.abs(", rewritten.code)
76+
self.assertIn("abs", rewritten.dump)
7977
Model.forward = rewritten.func
80-
Model()(x, y)
78+
self.assertEqualAny(expected, Model()(x, y))
79+
self.assertEqualAny(expected_, Model()(-x, y))
8180

8281
DYN = torch.export.Dim.DYNAMIC
8382
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
8483
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
8584
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
86-
got = ep.module()(x, y)
87-
self.assertEqualArray(expected, got)
88-
self.assertEqualArray(Model()(-x, y), ep.module()(-x, y))
85+
self.assertEqualAny(expected, ep.module()(x, y))
86+
self.assertEqualArray(expected_, ep.module()(-x, y))
8987

9088
def test_rewrite_forward_assign2(self):
9189

9290
class Model(torch.nn.Module):
93-
def __init__(self):
94-
super().__init__()
95-
9691
def forward(self, x, y):
9792
if x.sum() > 0:
9893
w, z = x + y, x - y
9994
else:
100-
w, z = torch.abs(x) + y, torch.abs(x) - y
95+
w, z = torch.abs(x) + y + 1, torch.abs(x) - y + 1
10196
return w, z
10297

10398
x, y = torch.rand((3, 4)), torch.rand((3, 4))
104-
expected = Model()(x, y)
99+
expected, expected_ = Model()(x, y), Model()(-x, y)
105100

106101
rewritten = transform_method(Model.forward, verbose=0)
102+
self.assertIn("torch.abs(", rewritten.code)
103+
self.assertIn("abs", rewritten.dump)
107104
Model.forward = rewritten.func
108-
Model()(x, y)
105+
self.assertEqualAny(expected, Model()(x, y))
106+
self.assertEqualAny(expected_, Model()(-x, y))
109107

110108
DYN = torch.export.Dim.DYNAMIC
111109
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
112110
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
113111
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
114-
got = ep.module()(x, y)
115-
self.assertEqualAny(expected, got)
116-
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
112+
self.assertEqualAny(expected, ep.module()(x, y))
113+
self.assertEqualAny(expected_, ep.module()(-x, y))
117114

118115
def test_rewrite_forward_noelse(self):
119116

120117
class Model(torch.nn.Module):
121-
def __init__(self):
122-
super().__init__()
123-
124118
def forward(self, x, y):
125119
if x.sum() > 0:
126-
x = torch.abs(x)
120+
x = torch.abs(x) + 1
127121
return x + y
128122

129123
x, y = torch.rand((3, 4)), torch.rand((3, 4))
130-
expected = Model()(x, y)
124+
expected, expected_ = Model()(x, y), Model()(-x, y)
131125

132126
rewritten = transform_method(Model.forward, verbose=0)
127+
self.assertIn("torch.abs(", rewritten.code)
128+
self.assertIn("abs", rewritten.dump)
133129
Model.forward = rewritten.func
134-
Model()(x, y)
130+
self.assertEqualAny(expected, Model()(x, y))
131+
self.assertEqualAny(expected_, Model()(-x, y))
135132

136133
DYN = torch.export.Dim.DYNAMIC
137134
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
138135
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
139136
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
140-
got = ep.module()(x, y)
141-
self.assertEqualAny(expected, got)
142-
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
143-
self.assertEqualAny(Model()(-x, y), ep.module()(-x, y))
137+
self.assertEqualAny(expected, ep.module()(x, y))
138+
self.assertEqualAny(expected_, ep.module()(-x, y))
144139

145140

146141
if __name__ == "__main__":

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 113 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def _settl(node, lineno, level=0):
3333

3434

3535
class RewriteControlFlow(ast.NodeTransformer):
36+
"""
37+
The class rewrites tests with function ``torch_cond`` :func:`torch.cond`.
38+
"""
39+
3640
def __init__(self, wrapper_name):
3741
self.wrapper_name = wrapper_name
3842
self.counter = 0
@@ -54,7 +58,7 @@ def _find_id(self, exprs):
5458
vars.append(n.id)
5559
return sorted(set(vars))
5660

57-
def _rewrite_if(self, node, then_exprs, else_exprs, tgt=None):
61+
def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
5862
test_node = node.test
5963

6064
# extract free variables
@@ -64,20 +68,42 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt=None):
6468
else_vars = self._find_id(else_exprs)
6569
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
6670
then_expr, else_expr = None, None
67-
if tgt is None and len(then_exprs) == 1 and len(else_exprs) == 1:
71+
if tgt_mapping is None and len(then_exprs) == 1 and len(else_exprs) == 1:
6872
# return
6973
then_expr = then_exprs[0]
7074
else_expr = else_exprs[0]
71-
elif len(then_exprs) == 1 and len(else_exprs) == 1:
72-
# assignment but only one, so we assume it is the same
75+
elif (
76+
tgt_mapping
77+
and len(then_exprs) == 1
78+
and len(else_exprs) == 1
79+
and len(tgt_mapping) == 1
80+
):
81+
# assignment but only one
7382
then_expr = then_exprs[0]
7483
else_expr = else_exprs[0]
7584
else:
76-
raise NotImplementedError(
77-
f"Unable to rewrite node {node}, len(then_exprs)={len(then_exprs)}, "
78-
f"len(else_exprs)={len(else_exprs)}, "
85+
assert tgt_mapping, (
86+
f"then and else branchs do not have the same number "
87+
f"of assignments, we need more information to understand "
88+
f"which ones to return,"
7989
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
8090
)
91+
then_exprs = []
92+
else_exprs = []
93+
for t in tgt_mapping:
94+
then_e, else_e = tgt_mapping[t]
95+
then_exprs.append(then_e or ast.Name(else_e.id, ctx=ast.Load()))
96+
else_exprs.append(else_e or ast.Name(then_e.id, ctx=ast.Load()))
97+
then_expr = (
98+
then_exprs[0]
99+
if len(then_exprs) == 1
100+
else ast.Tuple(then_exprs, ctx=ast.Load())
101+
)
102+
else_expr = (
103+
else_exprs[0]
104+
if len(else_exprs) == 1
105+
else ast.Tuple(else_exprs, ctx=ast.Load())
106+
)
81107

82108
# build local funcs
83109
then_def = ast.FunctionDef(
@@ -128,6 +154,38 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt=None):
128154
)
129155
return then_def, else_def, call
130156

157+
def _make_targets(self, node, then_assigns, else_assigns):
158+
tgt_mapping = {}
159+
for a, then_or_else in [
160+
*[(a, True) for a in then_assigns],
161+
*[(a, False) for a in else_assigns],
162+
]:
163+
for t in a.targets:
164+
if isinstance(t, ast.Name):
165+
if t.id not in tgt_mapping:
166+
tgt_mapping[t.id] = (t, None) if then_or_else else (None, t)
167+
else:
168+
v = tgt_mapping[t.id]
169+
tgt_mapping[t.id] = (t, v[1]) if then_or_else else (v[0], t)
170+
continue
171+
172+
assert isinstance(t, ast.Tuple) and all(
173+
isinstance(_, ast.Name) for _ in t.elts
174+
), (
175+
f"Unexpected assignment. Not Supported."
176+
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
177+
)
178+
for _t in t.elts:
179+
if _t.id not in tgt_mapping:
180+
tgt_mapping[_t.id] = (_t, None) if then_or_else else (None, _t)
181+
else:
182+
v = tgt_mapping[_t.id]
183+
tgt_mapping[_t.id] = (_t, v[1]) if then_or_else else (v[0], _t)
184+
185+
d = [(v[0] or v[1]) for k, v in sorted(dict(tgt_mapping).items())]
186+
tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load())
187+
return tgt, tgt_mapping
188+
131189
def visit_If(self, node):
132190
# First recurse into subnodes
133191
node = self.generic_visit(node)
@@ -152,32 +210,17 @@ def visit_If(self, node):
152210
then_assigns = [n for n in node.body if isinstance(n, ast.Assign)]
153211
else_assigns = [n for n in node.orelse if isinstance(n, ast.Assign)]
154212
assert then_assigns or else_assigns, (
155-
f"Missing assignment\n--\n"
213+
f"Missing assignment"
156214
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
157215
)
158216

159-
targets = []
160-
for a in [*then_assigns, *else_assigns]:
161-
for t in a.targets:
162-
if isinstance(t, ast.Name):
163-
targets.append((t.id, t))
164-
continue
165-
166-
assert isinstance(t, ast.Tuple) and all(
167-
isinstance(_, ast.Name) for _ in t.elts
168-
), (
169-
f"Unexpected assignment. Not Supported."
170-
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
171-
)
172-
targets.extend((_.id, _) for _ in t.elts)
173-
174-
d = [_[1] for _ in sorted(dict(targets).items())]
175-
tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load())
217+
# the targets we need to export
218+
tgt, tgt_mapping = self._make_targets(node, then_assigns, else_assigns)
176219

177220
then_values = [n.value for n in then_assigns]
178221
else_values = [n.value for n in else_assigns]
179222
then_def, else_def, call = self._rewrite_if(
180-
node, then_values, else_values, tgt=tgt
223+
node, then_values, else_values, tgt_mapping=tgt_mapping
181224
)
182225

183226
assign = ast.Assign(targets=[tgt], value=call)
@@ -226,6 +269,11 @@ def code(self) -> str:
226269
"""Returns the source."""
227270
return ast.unparse(self.tree)
228271

272+
@property
273+
def dump(self) -> str:
274+
"""Returns the tree dumped as a string."""
275+
return ast.dump(self.tree, indent=2)
276+
229277
def __repr__(self):
230278
"usual"
231279
return f"{self.__class__.__name__}({self.func})"
@@ -238,10 +286,49 @@ def transform_method(
238286
Returns a new function based on `func` where every test (if)
239287
is replaced by a call to :func:`torch.cond`.
240288
289+
A test must return the same things if it returns something
290+
or assign something. It cannot return in one branch and assign
291+
in the other branch.
292+
293+
.. warning:: room for improvment
294+
295+
When it assigns a value to a constant,
296+
the current implementation does check which ones is really used
297+
after the test. The rewritten local functions returns every
298+
assigned variable. This could be reduced.
299+
241300
:param func: method or function to rewrite
242301
:param if_name: function calling the test
243302
:param verbose: verbosity
244303
:return: rewritten method
304+
305+
.. runpython::
306+
:showcode:
307+
308+
import torch
309+
from onnx_diagnostic.torch_export_patches.patch_module import transform_method
310+
311+
class Model(torch.nn.Module):
312+
def forward(self, x, y):
313+
if x.sum() > 0:
314+
return x + y, x - y
315+
else:
316+
return torch.abs(x) + y, torch.abs(x) - y
317+
318+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
319+
expected = Model()(x, y)
320+
321+
rewritten = transform_method(Model.forward, verbose=10)
322+
print("-- code --")
323+
print(rewritten.code)
324+
325+
print(" -- export --")
326+
Model.forward = rewritten.func
327+
328+
DYN = torch.export.Dim.DYNAMIC
329+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
330+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
331+
print(ep)
245332
"""
246333
# Retrieve source of the function
247334
src = inspect.getsource(func)

0 commit comments

Comments
 (0)