Skip to content

Commit ea5170b

Browse files
committed
almost
1 parent d303de4 commit ea5170b

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,11 @@ def forward(self, x, y):
194194
x, y = torch.rand((3, 4)), torch.rand((3, 4))
195195
expected, expected_ = Model()(x, y), Model()(-x, y)
196196

197-
rewritten = transform_method(Model.forward, verbose=0)
197+
rewritten = transform_method(Model.forward, verbose=1)
198198
self.assertIn("torch.abs(", rewritten.code)
199199
self.assertIn("abs", rewritten.dump)
200200
code = rewritten.code
201-
assert ("w, z, u" in code and "u, w, z" not in code) or (
202-
"w, z, u" not in code and "u, w, z" in code
203-
), f"Order mismatch in\n{code}"
201+
assert "w, z, u" not in code and "u, w, z" not in code, f"None dropped\n{code}"
204202
Model.forward = rewritten.func
205203
self.assertEqualAny(expected, Model()(x, y))
206204
self.assertEqualAny(expected_, Model()(-x, y))

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@ def _settl(node, lineno, level=0):
3535
class RewriteControlFlow(ast.NodeTransformer):
3636
"""
3737
The class rewrites tests with function ``torch_cond`` :func:`torch.cond`.
38+
``empty_tensor`` is a function returning an empty tensor,
39+
when a branch returns something the other branch does not.
3840
"""
3941

40-
def __init__(self, wrapper_name):
42+
def __init__(self, wrapper_name, empty_tensor: str = "make_empty_tensor"):
4143
self.wrapper_name = wrapper_name
44+
self.empty_tensor = empty_tensor
4245
self.counter = 0
4346
self.current_func_args = None
4447

@@ -60,6 +63,7 @@ def _find_id(self, exprs):
6063

6164
def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
6265
test_node = node.test
66+
drop = set()
6367

6468
# extract free variables
6569
then_name = f"{self.wrapper_name}_then_{self.counter}"
@@ -69,7 +73,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
6973
then_else_vars = set(
7074
_
7175
for _ in [*then_vars, *else_vars]
72-
if _ != "torch" and (not tgt_mapping or _ not in tgt_mapping)
76+
if _ != "torch" # and (not tgt_mapping or _ not in tgt_mapping)
7377
)
7478
then_ret, else_ret = None, None
7579
if tgt_mapping is None and len(then_exprs) == 1 and len(else_exprs) == 1:
@@ -80,15 +84,21 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
8084
else_exprs = [n for n in node.orelse if not isinstance(n, ast.Return)]
8185
else:
8286
assert tgt_mapping, (
83-
f"then and else branchs do not have the same number "
87+
f"then and else branches do not have the same number "
8488
f"of assignments, we need more information to understand "
8589
f"which ones to return,"
8690
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
8791
)
92+
drop = set()
8893
then_exprs, else_exprs = node.body, node.orelse
8994
then_rets, else_rets = [], []
90-
for t in tgt_mapping:
91-
then_e, else_e = tgt_mapping[t]
95+
for t, then_else in sorted(tgt_mapping.items()):
96+
then_e, else_e = then_else
97+
if (then_e is None or else_e is None) and t not in then_else_vars:
98+
# The variable is not used by one branch and it is not an input.
99+
# Let's drop it.
100+
drop.add(t)
101+
continue
92102
then_rets.append(then_e or ast.Name(else_e.id, ctx=ast.Load()))
93103
else_rets.append(else_e or ast.Name(then_e.id, ctx=ast.Load()))
94104
then_ret = (
@@ -145,7 +155,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
145155
],
146156
keywords=[],
147157
)
148-
return then_def, else_def, call
158+
return then_def, else_def, call, drop
149159

150160
def _filter_target(self, node, tgt_mapping):
151161
"""
@@ -220,9 +230,13 @@ def visit_If(self, node):
220230
# the targets we need to export
221231
tgt, tgt_mapping = self._make_targets(node, then_assigns, else_assigns)
222232

223-
then_def, else_def, call = self._rewrite_if(
233+
then_def, else_def, call, dropped = self._rewrite_if(
224234
node, then_assigns, else_assigns, tgt_mapping=tgt_mapping
225235
)
236+
if dropped and isinstance(tgt, ast.Tuple):
237+
tgt = ast.Tuple(
238+
tuple(t for t in tgt.elts if t.id not in dropped), ctx=ast.Load()
239+
)
226240

227241
assign = ast.Assign(targets=[tgt], value=call)
228242
ast.copy_location(assign, node)
@@ -242,7 +256,7 @@ def visit_If(self, node):
242256
)
243257
then_expr = then_ret.value
244258
else_expr = else_ret.value
245-
then_def, else_def, call = self._rewrite_if(node, [then_expr], [else_expr])
259+
then_def, else_def, call, dropped = self._rewrite_if(node, [then_expr], [else_expr])
246260
ret = ast.Return(call)
247261
ast.copy_location(ret, node)
248262
ast.fix_missing_locations(ret)
@@ -281,7 +295,10 @@ def __repr__(self):
281295

282296

283297
def transform_method(
284-
func: Callable, if_name="torch_cond", verbose: int = 0
298+
func: Callable,
299+
if_name: str = "torch_cond",
300+
empty_tensor: str = "make_empty_tensor",
301+
verbose: int = 0,
285302
) -> RewrittenMethod:
286303
"""
287304
Returns a new function based on `func` where every test (if)
@@ -291,7 +308,7 @@ def transform_method(
291308
or assign something. It cannot return in one branch and assign
292309
in the other branch.
293310
294-
.. warning:: room for improvment
311+
.. warning:: room for improvement
295312
296313
When it assigns a value to a constant,
297314
the current implementation does check which ones is really used
@@ -301,6 +318,7 @@ def transform_method(
301318
302319
:param func: method or function to rewrite
303320
:param if_name: function calling the test
321+
:param empty_tensor: function creating an empty tensor
304322
:param verbose: verbosity
305323
:return: rewritten method
306324
@@ -341,7 +359,7 @@ def forward(self, x, y):
341359
if verbose > 1:
342360
print(f"[transform_method] -- tree --\n\n{ast.dump(tree, indent=2)}")
343361
# Apply transformation
344-
transformer = RewriteControlFlow(if_name)
362+
transformer = RewriteControlFlow(if_name, empty_tensor=empty_tensor)
345363
new_tree = transformer.visit(tree)
346364
if verbose > 1:
347365
print(f"[transform_method] -- new tree --\n\n{ast.dump(tree, indent=2)}")
@@ -372,6 +390,7 @@ def forward(self, x, y):
372390
namespace: Dict[str, type] = {}
373391
globs = func.__globals__.copy()
374392
globs[if_name] = torch.cond
393+
globs[empty_tensor] = lambda: torch.tensor([])
375394
exec(mod, globs, namespace)
376395
new_func = namespace.get(func.__name__)
377396
if not isinstance(new_func, types.FunctionType):

0 commit comments

Comments
 (0)