@@ -46,36 +46,45 @@ def visit_FunctionDef(self, node):
4646 self .current_func_args = old_args
4747 return node
4848
49- def _rewrite_if (self , node , then_expr , else_expr ):
49+ def _find_id (self , exprs ):
50+ vars = []
51+ for expr in exprs :
52+ for n in ast .walk (expr ):
53+ if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load ):
54+ vars .append (n .id )
55+ return sorted (set (vars ))
56+
57+ def _rewrite_if (self , node , then_exprs , else_exprs , tgt = None ):
5058 test_node = node .test
5159
5260 # extract free variables
5361 then_name = f"{ self .wrapper_name } _then_{ self .counter } "
5462 else_name = f"{ self .wrapper_name } _else_{ self .counter } "
55- then_vars = sorted (
56- {
57- n .id
58- for n in ast .walk (then_expr )
59- if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
60- }
61- )
62- else_vars = sorted (
63- {
64- n .id
65- for n in ast .walk (else_expr )
66- if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
67- }
68- )
69-
63+ then_vars = self ._find_id (then_exprs )
64+ else_vars = self ._find_id (else_exprs )
7065 then_else_vars = set (_ for _ in [* then_vars , * else_vars ] if _ != "torch" )
66+ then_expr , else_expr = None , None
67+ if tgt is None and len (then_exprs ) == 1 and len (else_exprs ) == 1 :
68+ # return
69+ then_expr = then_exprs [0 ]
70+ else_expr = else_exprs [0 ]
71+ elif len (then_exprs ) == 1 and len (else_exprs ) == 1 :
72+ # assignment but only one, so we assume it is the same
73+ then_expr = then_exprs [0 ]
74+ else_expr = else_exprs [0 ]
75+ else :
76+ raise NotImplementedError (
77+ f"Unable to rewrite node { node } , len(then_exprs)={ len (then_exprs )} , "
78+ f"len(else_exprs)={ len (else_exprs )} , "
79+ f"\n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
80+ )
7181
7282 # build local funcs
73- then_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
7483 then_def = ast .FunctionDef (
7584 name = then_name ,
7685 args = ast .arguments (
7786 posonlyargs = [],
78- args = then_args ,
87+ args = [ ast . arg ( arg = v , annotation = None ) for v in then_else_vars ] ,
7988 kwonlyargs = [],
8089 kw_defaults = [],
8190 defaults = [],
@@ -84,12 +93,11 @@ def _rewrite_if(self, node, then_expr, else_expr):
8493 decorator_list = [],
8594 returns = None ,
8695 )
87- else_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
8896 else_def = ast .FunctionDef (
8997 name = else_name ,
9098 args = ast .arguments (
9199 posonlyargs = [],
92- args = else_args ,
100+ args = [ ast . arg ( arg = v , annotation = None ) for v in then_else_vars ] ,
93101 kwonlyargs = [],
94102 kw_defaults = [],
95103 defaults = [],
@@ -124,50 +132,77 @@ def visit_If(self, node):
124132 # First recurse into subnodes
125133 node = self .generic_visit (node )
126134
127- # Case 1: simple assignment in both branches
128- if (
129- len (node .body ) == 1
130- and isinstance (node .body [0 ], ast .Assign )
131- and len (node .orelse ) == 1
132- and isinstance (node .orelse [0 ], ast .Assign )
133- and self .current_func_args is not None
134- ):
135- then_assign = node .body [0 ]
136- else_assign = node .orelse [0 ]
137- tgt = then_assign .targets [0 ]
138- if (
139- isinstance (tgt , ast .Name )
140- and isinstance (else_assign .targets [0 ], ast .Name )
141- and tgt .id == else_assign .targets [0 ].id
142- ):
143- self .counter += 1
144- then_expr = then_assign .value
145- else_expr = else_assign .value
146- then_def , else_def , call = self ._rewrite_if (node , then_expr , else_expr )
147- assign = ast .Assign (targets = [tgt ], value = call )
148- ast .copy_location (assign , node )
149- ast .fix_missing_locations (assign )
150- return [then_def , else_def , assign ]
135+ has_then_return = any (isinstance (n , ast .Return ) for n in node .body )
136+ has_else_return = any (isinstance (n , ast .Return ) for n in node .orelse )
137+ ok = (has_then_return and has_else_return ) or (
138+ not has_then_return and not has_else_return
139+ )
140+ assert ok , (
141+ f"Cannot mix return and assignment in a test\n --\n "
142+ f"{ ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
143+ )
144+ assert self .current_func_args is not None , (
145+ f"current_func_args is None\n --\n "
146+ f"{ ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
147+ )
148+ self .counter += 1
151149
152- # Case 2: simple return in both branches
153- if (
154- len (node .body ) == 1
155- and isinstance (node .body [0 ], ast .Return )
156- and len (node .orelse ) == 1
157- and isinstance (node .orelse [0 ], ast .Return )
158- and self .current_func_args is not None
159- ):
160- then_ret = node .body [0 ]
161- else_ret = node .orelse [0 ]
162- then_expr = then_ret .value
163- else_expr = else_ret .value
164- self .counter += 1
165- then_def , else_def , call = self ._rewrite_if (node , then_expr , else_expr )
166- ret = ast .Return (call )
167- ast .copy_location (ret , node )
168- ast .fix_missing_locations (ret )
169- return [then_def , else_def , ret ]
170- return node
150+ if not has_then_return :
151+ # Case 1: simple assignment in both branches
152+ then_assigns = [n for n in node .body if isinstance (n , ast .Assign )]
153+ else_assigns = [n for n in node .orelse if isinstance (n , ast .Assign )]
154+ assert then_assigns or else_assigns , (
155+ f"Missing assignment\n --\n "
156+ f"\n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
157+ )
158+
159+ targets = []
160+ for a in [* then_assigns , * else_assigns ]:
161+ for t in a .targets :
162+ if isinstance (t , ast .Name ):
163+ targets .append ((t .id , t ))
164+ continue
165+
166+ assert isinstance (t , ast .Tuple ) and all (
167+ isinstance (_ , ast .Name ) for _ in t .elts
168+ ), (
169+ f"Unexpected assignment. Not Supported."
170+ f"\n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
171+ )
172+ targets .extend ((_ .id , _ ) for _ in t .elts )
173+
174+ d = [_ [1 ] for _ in sorted (dict (targets ).items ())]
175+ tgt = d [0 ] if len (d ) == 1 else ast .Tuple (d , ctx = ast .Load ())
176+
177+ then_values = [n .value for n in then_assigns ]
178+ else_values = [n .value for n in else_assigns ]
179+ then_def , else_def , call = self ._rewrite_if (
180+ node , then_values , else_values , tgt = tgt
181+ )
182+
183+ assign = ast .Assign (targets = [tgt ], value = call )
184+ ast .copy_location (assign , node )
185+ ast .fix_missing_locations (assign )
186+ return [then_def , else_def , assign ]
187+
188+ # Case 2: return in both branches, we assume both branches return the same results.
189+ then_ret = node .body [- 1 ]
190+ else_ret = node .orelse [- 1 ]
191+ assert isinstance (then_ret , ast .Return ), (
192+ f"return is not the last instruction of then branch"
193+ f"\n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
194+ )
195+ assert isinstance (else_ret , ast .Return ), (
196+ f"return is not the last instruction of else branch"
197+ f"\n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
198+ )
199+ then_expr = then_ret .value
200+ else_expr = else_ret .value
201+ then_def , else_def , call = self ._rewrite_if (node , [then_expr ], [else_expr ])
202+ ret = ast .Return (call )
203+ ast .copy_location (ret , node )
204+ ast .fix_missing_locations (ret )
205+ return [then_def , else_def , ret ]
171206
172207 def generic_visit (self , node ):
173208 return super ().generic_visit (node )
@@ -211,26 +246,24 @@ def transform_method(
211246 # Retrieve source of the function
212247 src = inspect .getsource (func )
213248 if verbose :
214- print (f"[transform_method] -- source -- { func } " )
215- print (src )
249+ print (f"[transform_method] -- source -- { func } \n \n { src } \n \n [transform_method] --" )
216250 # Parse into AST
217251 tree = ast .parse (textwrap .dedent (src ))
218252 if verbose > 1 :
219- print ("[transform_method] -- tree --" )
220- print (ast .dump (tree , indent = 2 ))
253+ print (f"[transform_method] -- tree --\n \n { ast .dump (tree , indent = 2 )} " )
221254 # Apply transformation
222255 transformer = RewriteControlFlow (if_name )
223256 new_tree = transformer .visit (tree )
224257 if verbose > 1 :
225- print ("[transform_method] -- new tree --" )
226- print (ast .dump (tree , indent = 2 ))
258+ print (f"[transform_method] -- new tree --\n \n { ast .dump (tree , indent = 2 )} " )
227259 ast .fix_missing_locations (new_tree )
228260 _settl (new_tree , 0 )
229261
230262 if verbose > 0 :
231- print ("[transform_method] -- new code --" )
232- code = ast .unparse (new_tree )
233- print (code )
263+ print (
264+ f"[transform_method] -- new code --\n \n "
265+ f"{ ast .unparse (new_tree )} \n \n [transform_method] --"
266+ )
234267 try :
235268 mod = compile (new_tree , filename = "<ast>" , mode = "exec" )
236269 except TypeError as e :
0 commit comments