Skip to content

Commit d7c15ed

Browse files
committed
fix a few things
1 parent 105cbe6 commit d7c15ed

File tree

2 files changed

+17
-30
lines changed

2 files changed

+17
-30
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def forward(self, x, y):
112112
self.assertEqualAny(expected, ep.module()(x, y))
113113
self.assertEqualAny(expected_, ep.module()(-x, y))
114114

115-
def test_rewrite_forward_noelse(self):
115+
def test_rewrite_forward_assign_noelse(self):
116116

117117
class Model(torch.nn.Module):
118118
def forward(self, x, y):

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -67,42 +67,31 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
6767
then_vars = self._find_id(then_exprs)
6868
else_vars = self._find_id(else_exprs)
6969
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
70-
then_expr, else_expr = None, None
70+
then_ret, else_ret = None, None
7171
if tgt_mapping is None and len(then_exprs) == 1 and len(else_exprs) == 1:
7272
# return
73-
then_expr = then_exprs[0]
74-
else_expr = else_exprs[0]
75-
elif (
76-
tgt_mapping
77-
and len(then_exprs) == 1
78-
and len(else_exprs) == 1
79-
and len(tgt_mapping) == 1
80-
):
81-
# assignment but only one
82-
then_expr = then_exprs[0]
83-
else_expr = else_exprs[0]
73+
then_exprs = [n for n in node.body if not isinstance(n, ast.Return)]
74+
else_exprs = [n for n in node.orelse if not isinstance(n, ast.Return)]
75+
then_ret = then_exprs[0]
76+
else_ret = else_exprs[0]
8477
else:
8578
assert tgt_mapping, (
8679
f"then and else branchs do not have the same number "
8780
f"of assignments, we need more information to understand "
8881
f"which ones to return,"
8982
f"\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
9083
)
91-
then_exprs = []
92-
else_exprs = []
84+
then_exprs, else_exprs = node.body, node.orelse
85+
then_rets, else_rets = [], []
9386
for t in tgt_mapping:
9487
then_e, else_e = tgt_mapping[t]
95-
then_exprs.append(then_e or ast.Name(else_e.id, ctx=ast.Load()))
96-
else_exprs.append(else_e or ast.Name(then_e.id, ctx=ast.Load()))
97-
then_expr = (
98-
then_exprs[0]
99-
if len(then_exprs) == 1
100-
else ast.Tuple(then_exprs, ctx=ast.Load())
88+
then_rets.append(then_e or ast.Name(else_e.id, ctx=ast.Load()))
89+
else_rets.append(else_e or ast.Name(then_e.id, ctx=ast.Load()))
90+
then_ret = (
91+
then_rets[0] if len(then_rets) == 1 else ast.Tuple(then_rets, ctx=ast.Load())
10192
)
102-
else_expr = (
103-
else_exprs[0]
104-
if len(else_exprs) == 1
105-
else ast.Tuple(else_exprs, ctx=ast.Load())
93+
else_ret = (
94+
else_rets[0] if len(else_rets) == 1 else ast.Tuple(else_rets, ctx=ast.Load())
10695
)
10796

10897
# build local funcs
@@ -115,7 +104,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
115104
kw_defaults=[],
116105
defaults=[],
117106
),
118-
body=[ast.Return(then_expr)],
107+
body=[*then_exprs, ast.Return(then_ret)],
119108
decorator_list=[],
120109
returns=None,
121110
)
@@ -128,7 +117,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
128117
kw_defaults=[],
129118
defaults=[],
130119
),
131-
body=[ast.Return(else_expr)],
120+
body=[*else_exprs, ast.Return(else_ret)],
132121
decorator_list=[],
133122
returns=None,
134123
)
@@ -217,10 +206,8 @@ def visit_If(self, node):
217206
# the targets we need to export
218207
tgt, tgt_mapping = self._make_targets(node, then_assigns, else_assigns)
219208

220-
then_values = [n.value for n in then_assigns]
221-
else_values = [n.value for n in else_assigns]
222209
then_def, else_def, call = self._rewrite_if(
223-
node, then_values, else_values, tgt_mapping=tgt_mapping
210+
node, then_assigns, else_assigns, tgt_mapping=tgt_mapping
224211
)
225212

226213
assign = ast.Assign(targets=[tgt], value=call)

0 commit comments

Comments
 (0)