Skip to content

Commit a7de659

Browse files
committed
fix nested
1 parent 1e05229 commit a7de659

File tree

2 files changed

+109
-20
lines changed

2 files changed

+109
-20
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,100 @@ def forward(self, x, y):
231231
self.assertEqualAny(expected, ep.module()(x, y))
232232
self.assertEqualAny(expected_, ep.module()(-x, y))
233233

234+
def test_assign_nested_check(self):
235+
236+
torch_cond = torch.cond
237+
238+
class Model(torch.nn.Module):
239+
def forward(self, x, y):
240+
def torch_cond_then_3(y, x):
241+
242+
def torch_cond_then_1(y, x):
243+
w = x + y
244+
z = x - y
245+
return (w, z)
246+
247+
def torch_cond_else_1(y, x):
248+
u = x + 10
249+
w = x + torch.abs(y) + u
250+
z = x - torch.abs(y) + u
251+
return (w, z)
252+
253+
w, z = torch_cond(
254+
y.sum() > 0, torch_cond_then_1, torch_cond_else_1, [y, x]
255+
)
256+
return (w, z)
257+
258+
def torch_cond_else_3(y, x):
259+
260+
def torch_cond_then_2(y):
261+
u = y + 1
262+
return u
263+
264+
def torch_cond_else_2(y):
265+
u = torch.abs(y) + 10
266+
return u
267+
268+
u = torch_cond(y.sum() > 0, torch_cond_then_2, torch_cond_else_2, [y])
269+
w = torch.abs(x) + u
270+
z = torch.abs(x) - u
271+
return (w, z)
272+
273+
w, z = torch_cond(x.sum() > 0, torch_cond_then_3, torch_cond_else_3, [y, x])
274+
return (w, z)
275+
276+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
277+
Model()(x, y)
278+
279+
def test_rewrite_forward_assign_nested(self):
280+
281+
class Model(torch.nn.Module):
282+
def forward(self, x, y):
283+
if x.sum() > 0:
284+
if y.sum() > 0:
285+
w = x + y
286+
z = x - y
287+
else:
288+
u = x + 10
289+
w = x + torch.abs(y) + u
290+
z = x - torch.abs(y) + u
291+
else:
292+
if y.sum() > 0:
293+
u = y + 1
294+
else:
295+
u = torch.abs(y) + 10
296+
w = torch.abs(x) + u
297+
z = torch.abs(x) - u
298+
return w, z
299+
300+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
301+
expected, expected_, expected_0, expected_1 = (
302+
Model()(x, y),
303+
Model()(-x, y),
304+
Model()(x, -y),
305+
Model()(-x, -y),
306+
)
307+
308+
rewritten = transform_method(Model.forward, verbose=self.verbose)
309+
self.assertIn("torch.abs(", rewritten.code)
310+
self.assertIn("abs", rewritten.dump)
311+
code = rewritten.code
312+
self.assertIn("branch_cond_else_3", code)
313+
Model.forward = rewritten.func
314+
self.assertEqualAny(expected, Model()(x, y))
315+
self.assertEqualAny(expected_, Model()(-x, y))
316+
self.assertEqualAny(expected_0, Model()(x, -y))
317+
self.assertEqualAny(expected_1, Model()(-x, -y))
318+
319+
DYN = torch.export.Dim.DYNAMIC
320+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
321+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
322+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
323+
self.assertEqualAny(expected, ep.module()(x, y))
324+
self.assertEqualAny(expected_, ep.module()(-x, y))
325+
self.assertEqualAny(expected_0, ep.module()(x, -y))
326+
self.assertEqualAny(expected_1, ep.module()(-x, -y))
327+
234328

235329
if __name__ == "__main__":
236330
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import types
44
import textwrap
55
from typing import Callable, Dict, List, Set, Optional
6-
import torch
76

87
NODE_TYPES = tuple(
98
getattr(ast, k)
@@ -34,22 +33,20 @@ def _settl(node, lineno, level=0):
3433

3534
class RewriteControlFlow(ast.NodeTransformer):
3635
"""
37-
The class rewrites tests with function ``torch_cond`` :func:`torch.cond`.
36+
The class rewrites tests with function :func:`torch.cond`.
3837
``empty_tensor`` is a function returning an empty tensor,
3938
when a branch returns something the other branch does not.
4039
"""
4140

4241
def __init__(
4342
self,
44-
wrapper_name,
45-
empty_tensor: str = "make_empty_tensor",
43+
prefix: str = "branch_cond",
4644
skip_objects: Optional[Dict[str, object]] = None,
4745
args_names: Optional[Set[str]] = None,
4846
):
49-
self.wrapper_name = wrapper_name
50-
self.empty_tensor = empty_tensor
5147
self.counter = 0
5248
self.current_func_args = None
49+
self.prefix = prefix
5350
self.skip_objects = skip_objects or {}
5451
self.args_names = args_names or set()
5552
self.local_variables = self.args_names.copy()
@@ -83,7 +80,7 @@ def _find_id(self, exprs: List["ast.Node"]) -> List[str]:
8380
for n in ast.walk(expr):
8481
if (
8582
isinstance(n, ast.Name)
86-
and isinstance(n.ctx, ast.Store)
83+
and isinstance(n.ctx, ast.Load)
8784
and n.id not in self.skip_objects
8885
):
8986
vars.append(n.id)
@@ -97,8 +94,8 @@ def _rewrite_if(
9794
drop = set()
9895

9996
# extract free variables
100-
then_name = f"{self.wrapper_name}_then_{self.counter}"
101-
else_name = f"{self.wrapper_name}_else_{self.counter}"
97+
then_name = f"{self.prefix}_then_{self.counter}"
98+
else_name = f"{self.prefix}_else_{self.counter}"
10299
then_vars = self._find_id(then_exprs)
103100
else_vars = self._find_id(else_exprs)
104101
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ in known_local_variables)
@@ -173,8 +170,11 @@ def _rewrite_if(
173170
[ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
174171
ctx=ast.Load(),
175172
)
173+
176174
call = ast.Call(
177-
func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
175+
func=ast.Attribute(
176+
value=ast.Name(id="torch", ctx=ast.Load()), attr="cond", ctx=ast.Load()
177+
),
178178
args=[
179179
test_node,
180180
ast.Name(id=then_name, ctx=ast.Load()),
@@ -346,8 +346,7 @@ def inplace_add_parent(tree: "ast.Node"):
346346

347347
def transform_method(
348348
func: Callable,
349-
if_name: str = "torch_cond",
350-
empty_tensor: str = "make_empty_tensor",
349+
prefix: str = "branch_cond",
351350
verbose: int = 0,
352351
) -> RewrittenMethod:
353352
"""
@@ -367,8 +366,7 @@ def transform_method(
367366
See method ``_filter_target``.
368367
369368
:param func: method or function to rewrite
370-
:param if_name: function calling the test
371-
:param empty_tensor: function creating an empty tensor
369+
:param prefix: prefix used to create the functions for the branches
372370
:param verbose: verbosity
373371
:return: rewritten method
374372
@@ -390,7 +388,7 @@ def forward(self, x, y):
390388
x, y = torch.rand((3, 4)), torch.rand((3, 4))
391389
expected = Model()(x, y)
392390
393-
rewritten = transform_method(Model.forward, verbose=10)
391+
rewritten = transform_method(Model.forward)
394392
print("-- code --")
395393
print(rewritten.code)
396394
@@ -423,7 +421,7 @@ def forward(self, x, y):
423421
x, y = torch.rand((3, 4)), torch.rand((3, 4))
424422
expected = Model()(x, y)
425423
426-
rewritten = transform_method(Model.forward, verbose=10)
424+
rewritten = transform_method(Model.forward)
427425
print("-- code --")
428426
print(rewritten.code)
429427
@@ -447,8 +445,7 @@ def forward(self, x, y):
447445
print(f"[transform_method] -- tree --\n\n{ast.dump(tree, indent=2)}")
448446
# Apply transformation
449447
transformer = RewriteControlFlow(
450-
if_name,
451-
empty_tensor=empty_tensor,
448+
prefix=prefix,
452449
skip_objects=modules,
453450
args_names=set(sig.parameters),
454451
)
@@ -482,8 +479,6 @@ def forward(self, x, y):
482479
) from e
483480
namespace: Dict[str, type] = {}
484481
globs = func.__globals__.copy()
485-
globs[if_name] = torch.cond
486-
globs[empty_tensor] = lambda: torch.tensor([])
487482
exec(mod, globs, namespace)
488483
new_func = namespace.get(func.__name__)
489484
if not isinstance(new_func, types.FunctionType):

0 commit comments

Comments
 (0)