@@ -35,10 +35,13 @@ def _settl(node, lineno, level=0):
3535class RewriteControlFlow (ast .NodeTransformer ):
3636 """
3737 The class rewrites tests with function ``torch_cond`` :func:`torch.cond`.
38+ ``empty_tensor`` is a function returning an empty tensor,
39+ when a branch returns something the other branch does not.
3840 """
3941
40- def __init__ (self , wrapper_name ):
42+ def __init__ (self , wrapper_name , empty_tensor : str = "make_empty_tensor" ):
4143 self .wrapper_name = wrapper_name
44+ self .empty_tensor = empty_tensor
4245 self .counter = 0
4346 self .current_func_args = None
4447
@@ -60,6 +63,7 @@ def _find_id(self, exprs):
6063
6164 def _rewrite_if (self , node , then_exprs , else_exprs , tgt_mapping = None ):
6265 test_node = node .test
66+ drop = set ()
6367
6468 # extract free variables
6569 then_name = f"{ self .wrapper_name } _then_{ self .counter } "
@@ -69,7 +73,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
6973 then_else_vars = set (
7074 _
7175 for _ in [* then_vars , * else_vars ]
72- if _ != "torch" and (not tgt_mapping or _ not in tgt_mapping )
76+ if _ != "torch" # and (not tgt_mapping or _ not in tgt_mapping)
7377 )
7478 then_ret , else_ret = None , None
7579 if tgt_mapping is None and len (then_exprs ) == 1 and len (else_exprs ) == 1 :
@@ -80,15 +84,21 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
8084 else_exprs = [n for n in node .orelse if not isinstance (n , ast .Return )]
8185 else :
8286 assert tgt_mapping , (
83- f"then and else branchs do not have the same number "
87+ f"then and else branches do not have the same number "
8488 f"of assignments, we need more information to understand "
8589 f"which ones to return,"
8690 f"\n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
8791 )
92+ drop = set ()
8893 then_exprs , else_exprs = node .body , node .orelse
8994 then_rets , else_rets = [], []
90- for t in tgt_mapping :
91- then_e , else_e = tgt_mapping [t ]
95+ for t , then_else in sorted (tgt_mapping .items ()):
96+ then_e , else_e = then_else
97+ if (then_e is None or else_e is None ) and t not in then_else_vars :
98+ # The variable is not used by one branch and it is not an input.
99+ # Let's drop it.
100+ drop .add (t )
101+ continue
92102 then_rets .append (then_e or ast .Name (else_e .id , ctx = ast .Load ()))
93103 else_rets .append (else_e or ast .Name (then_e .id , ctx = ast .Load ()))
94104 then_ret = (
@@ -145,7 +155,7 @@ def _rewrite_if(self, node, then_exprs, else_exprs, tgt_mapping=None):
145155 ],
146156 keywords = [],
147157 )
148- return then_def , else_def , call
158+ return then_def , else_def , call , drop
149159
150160 def _filter_target (self , node , tgt_mapping ):
151161 """
@@ -220,9 +230,13 @@ def visit_If(self, node):
220230 # the targets we need to export
221231 tgt , tgt_mapping = self ._make_targets (node , then_assigns , else_assigns )
222232
223- then_def , else_def , call = self ._rewrite_if (
233+ then_def , else_def , call , dropped = self ._rewrite_if (
224234 node , then_assigns , else_assigns , tgt_mapping = tgt_mapping
225235 )
236+ if dropped and isinstance (tgt , ast .Tuple ):
237+ tgt = ast .Tuple (
238+ tuple (t for t in tgt .elts if t .id not in dropped ), ctx = ast .Load ()
239+ )
226240
227241 assign = ast .Assign (targets = [tgt ], value = call )
228242 ast .copy_location (assign , node )
@@ -242,7 +256,7 @@ def visit_If(self, node):
242256 )
243257 then_expr = then_ret .value
244258 else_expr = else_ret .value
245- then_def , else_def , call = self ._rewrite_if (node , [then_expr ], [else_expr ])
259+ then_def , else_def , call , dropped = self ._rewrite_if (node , [then_expr ], [else_expr ])
246260 ret = ast .Return (call )
247261 ast .copy_location (ret , node )
248262 ast .fix_missing_locations (ret )
@@ -281,7 +295,10 @@ def __repr__(self):
281295
282296
283297def transform_method (
284- func : Callable , if_name = "torch_cond" , verbose : int = 0
298+ func : Callable ,
299+ if_name : str = "torch_cond" ,
300+ empty_tensor : str = "make_empty_tensor" ,
301+ verbose : int = 0 ,
285302) -> RewrittenMethod :
286303 """
287304 Returns a new function based on `func` where every test (if)
@@ -291,7 +308,7 @@ def transform_method(
291308 or assign something. It cannot return in one branch and assign
292309 in the other branch.
293310
294- .. warning:: room for improvment
311+ .. warning:: room for improvement
295312
296313 When it assigns a value to a constant,
297314 the current implementation does check which ones is really used
@@ -301,6 +318,7 @@ def transform_method(
301318
302319 :param func: method or function to rewrite
303320 :param if_name: function calling the test
321+ :param empty_tensor: function creating an empty tensor
304322 :param verbose: verbosity
305323 :return: rewritten method
306324
@@ -341,7 +359,7 @@ def forward(self, x, y):
341359 if verbose > 1 :
342360 print (f"[transform_method] -- tree --\n \n { ast .dump (tree , indent = 2 )} " )
343361 # Apply transformation
344- transformer = RewriteControlFlow (if_name )
362+ transformer = RewriteControlFlow (if_name , empty_tensor = empty_tensor )
345363 new_tree = transformer .visit (tree )
346364 if verbose > 1 :
347365 print (f"[transform_method] -- new tree --\n \n { ast .dump (tree , indent = 2 )} " )
@@ -372,6 +390,7 @@ def forward(self, x, y):
372390 namespace : Dict [str , type ] = {}
373391 globs = func .__globals__ .copy ()
374392 globs [if_name ] = torch .cond
393+ globs [empty_tensor ] = lambda : torch .tensor ([])
375394 exec (mod , globs , namespace )
376395 new_func = namespace .get (func .__name__ )
377396 if not isinstance (new_func , types .FunctionType ):
0 commit comments