@@ -67,42 +67,31 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
6767 then_vars = self ._find_id (then_exprs )
6868 else_vars = self ._find_id (else_exprs )
6969 then_else_vars = set (_ for _ in [* then_vars , * else_vars ] if _ != "torch" )
70- then_expr , else_expr = None , None
70+ then_ret , else_ret = None , None
7171 if tgt_mapping is None and len (then_exprs ) == 1 and len (else_exprs ) == 1 :
7272 # return
73- then_expr = then_exprs [0 ]
74- else_expr = else_exprs [0 ]
75- elif (
76- tgt_mapping
77- and len (then_exprs ) == 1
78- and len (else_exprs ) == 1
79- and len (tgt_mapping ) == 1
80- ):
81- # assignment but only one
82- then_expr = then_exprs [0 ]
83- else_expr = else_exprs [0 ]
73+ then_exprs = [n for n in node .body if not isinstance (n , ast .Return )]
74+ else_exprs = [n for n in node .orelse if not isinstance (n , ast .Return )]
75+ then_ret = then_exprs [0 ]
76+ else_ret = else_exprs [0 ]
8477 else :
8578 assert tgt_mapping , (
8679 f"then and else branchs do not have the same number "
8780 f"of assignments, we need more information to understand "
8881 f"which ones to return,"
8982 f"\n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
9083 )
91- then_exprs = []
92- else_exprs = []
84+ then_exprs , else_exprs = node . body , node . orelse
85+ then_rets , else_rets = [], []
9386 for t in tgt_mapping :
9487 then_e , else_e = tgt_mapping [t ]
95- then_exprs .append (then_e or ast .Name (else_e .id , ctx = ast .Load ()))
96- else_exprs .append (else_e or ast .Name (then_e .id , ctx = ast .Load ()))
97- then_expr = (
98- then_exprs [0 ]
99- if len (then_exprs ) == 1
100- else ast .Tuple (then_exprs , ctx = ast .Load ())
88+ then_rets .append (then_e or ast .Name (else_e .id , ctx = ast .Load ()))
89+ else_rets .append (else_e or ast .Name (then_e .id , ctx = ast .Load ()))
90+ then_ret = (
91+ then_rets [0 ] if len (then_rets ) == 1 else ast .Tuple (then_rets , ctx = ast .Load ())
10192 )
102- else_expr = (
103- else_exprs [0 ]
104- if len (else_exprs ) == 1
105- else ast .Tuple (else_exprs , ctx = ast .Load ())
93+ else_ret = (
94+ else_rets [0 ] if len (else_rets ) == 1 else ast .Tuple (else_rets , ctx = ast .Load ())
10695 )
10796
10897 # build local funcs
@@ -115,7 +104,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
115104 kw_defaults = [],
116105 defaults = [],
117106 ),
118- body = [ast .Return (then_expr )],
107+ body = [* then_exprs , ast .Return (then_ret )],
119108 decorator_list = [],
120109 returns = None ,
121110 )
@@ -128,7 +117,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
128117 kw_defaults = [],
129118 defaults = [],
130119 ),
131- body = [ast .Return (else_expr )],
120+ body = [* else_exprs , ast .Return (else_ret )],
132121 decorator_list = [],
133122 returns = None ,
134123 )
@@ -217,10 +206,8 @@ def visit_If(self, node):
217206 # the targets we need to export
218207 tgt , tgt_mapping = self ._make_targets (node , then_assigns , else_assigns )
219208
220- then_values = [n .value for n in then_assigns ]
221- else_values = [n .value for n in else_assigns ]
222209 then_def , else_def , call = self ._rewrite_if (
223- node , then_values , else_values , tgt_mapping = tgt_mapping
210+ node , then_assigns , else_assigns , tgt_mapping = tgt_mapping
224211 )
225212
226213 assign = ast .Assign (targets = [tgt ], value = call )
0 commit comments