Skip to content

Commit d303de4

Browse files
committed
order
1 parent 904bfb7 commit d303de4

File tree

2 files changed

+95
-5
lines changed

2 files changed

+95
-5
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,81 @@ def forward(self, x, y):
137137
self.assertEqualAny(expected, ep.module()(x, y))
138138
self.assertEqualAny(expected_, ep.module()(-x, y))
139139

140+
def test_rewrite_forward_return_noelse(self):
141+
142+
class Model(torch.nn.Module):
143+
def forward(self, x, y):
144+
if x.sum() > 0:
145+
return torch.abs(x) + 1 + y
146+
return x + y
147+
148+
self.assertRaise(
149+
lambda: transform_method(Model.forward, verbose=0), NotImplementedError
150+
)
151+
152+
def test_rewrite_forward_assign2_in_2(self):
153+
154+
class Model(torch.nn.Module):
155+
def forward(self, x, y):
156+
if x.sum() > 0:
157+
w = x + y
158+
z = x - y
159+
else:
160+
w = torch.abs(x) + y + 1
161+
z = torch.abs(x) - y + 1
162+
return w, z
163+
164+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
165+
expected, expected_ = Model()(x, y), Model()(-x, y)
166+
167+
rewritten = transform_method(Model.forward, verbose=0)
168+
self.assertIn("torch.abs(", rewritten.code)
169+
self.assertIn("abs", rewritten.dump)
170+
Model.forward = rewritten.func
171+
self.assertEqualAny(expected, Model()(x, y))
172+
self.assertEqualAny(expected_, Model()(-x, y))
173+
174+
DYN = torch.export.Dim.DYNAMIC
175+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
176+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
177+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
178+
self.assertEqualAny(expected, ep.module()(x, y))
179+
self.assertEqualAny(expected_, ep.module()(-x, y))
180+
181+
def test_rewrite_forward_assign2_in_3(self):
182+
183+
class Model(torch.nn.Module):
184+
def forward(self, x, y):
185+
if x.sum() > 0:
186+
w = x + y
187+
z = x - y
188+
else:
189+
u = y + 1
190+
w = torch.abs(x) + u
191+
z = torch.abs(x) - u
192+
return w, z
193+
194+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
195+
expected, expected_ = Model()(x, y), Model()(-x, y)
196+
197+
rewritten = transform_method(Model.forward, verbose=0)
198+
self.assertIn("torch.abs(", rewritten.code)
199+
self.assertIn("abs", rewritten.dump)
200+
code = rewritten.code
201+
assert ("w, z, u" in code and "u, w, z" not in code) or (
202+
"w, z, u" not in code and "u, w, z" in code
203+
), f"Order mismatch in\n{code}"
204+
Model.forward = rewritten.func
205+
self.assertEqualAny(expected, Model()(x, y))
206+
self.assertEqualAny(expected_, Model()(-x, y))
207+
208+
DYN = torch.export.Dim.DYNAMIC
209+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
210+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
211+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
212+
self.assertEqualAny(expected, ep.module()(x, y))
213+
self.assertEqualAny(expected_, ep.module()(-x, y))
214+
140215

141216
if __name__ == "__main__":
142217
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
6666
else_name = f"{self.wrapper_name}_else_{self.counter}"
6767
then_vars = self._find_id(then_exprs)
6868
else_vars = self._find_id(else_exprs)
69-
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
69+
then_else_vars = set(
70+
_
71+
for _ in [*then_vars, *else_vars]
72+
if _ != "torch" and (not tgt_mapping or _ not in tgt_mapping)
73+
)
7074
then_ret, else_ret = None, None
7175
if tgt_mapping is None and len(then_exprs) == 1 and len(else_exprs) == 1:
7276
# return
@@ -143,6 +147,13 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
143147
)
144148
return then_def, else_def, call
145149

150+
def _filter_target(self, node, tgt_mapping):
151+
"""
152+
This function should reduce the number of elements to return
153+
by looking at the one used after the If statement.
154+
"""
155+
return tgt_mapping
156+
146157
def _make_targets(self, node, then_assigns, else_assigns):
147158
tgt_mapping = {}
148159
for a, then_or_else in [
@@ -171,6 +182,7 @@ def _make_targets(self, node, then_assigns, else_assigns):
171182
v = tgt_mapping[_t.id]
172183
tgt_mapping[_t.id] = (_t, v[1]) if then_or_else else (v[0], _t)
173184

185+
tgt_mapping = self._filter_target(node, tgt_mapping)
174186
d = [(v[0] or v[1]) for k, v in sorted(dict(tgt_mapping).items())]
175187
tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load())
176188
return tgt, tgt_mapping
@@ -184,10 +196,12 @@ def visit_If(self, node):
184196
ok = (has_then_return and has_else_return) or (
185197
not has_then_return and not has_else_return
186198
)
187-
assert ok, (
188-
f"Cannot mix return and assignment in a test\n--\n"
189-
f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
190-
)
199+
if not ok:
200+
raise NotImplementedError(
201+
f"Cannot mix return and assignment in a test or a "
202+
f"unique then branch with a return\n--\n"
203+
f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
204+
)
191205
assert self.current_func_args is not None, (
192206
f"current_func_args is None\n--\n"
193207
f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
@@ -283,6 +297,7 @@ def transform_method(
283297
the current implementation does check which ones is really used
284298
after the test. The rewritten local functions returns every
285299
assigned variable. This could be reduced.
300+
See method ``_filter_target``.
286301
287302
:param func: method or function to rewrite
288303
:param if_name: function calling the test

0 commit comments

Comments
 (0)