@@ -177,6 +177,18 @@ def _rewrite_if(
177177 else_ret = else_exprs [0 ]
178178 then_exprs = [n for n in node .body if not isinstance (n , ast .Return )]
179179 else_exprs = [n for n in node .orelse if not isinstance (n , ast .Return )]
180+ assert type (then_ret .value ) is type (else_ret .value ), (
181+ f"Inconsistencies return then value={ then_ret .value } , "
182+ f"else value={ else_ret .value } "
183+ )
184+ if isinstance (then_ret .value , (ast .Tuple , ast .list )):
185+ assert len (then_ret .value .elts ) == len (else_ret .value .elts ), (
186+ f"Unexpected number of elements on both branches, "
187+ f"then:{ then_ret .value .elts } , else:{ else_ret .value .elts } "
188+ )
189+ n_returned_values = len (then_ret .value .elts )
190+ else :
191+ n_returned_values = 0
180192 else :
181193 self ._check (
182194 tgt_mapping ,
@@ -207,6 +219,7 @@ def _rewrite_if(
207219 if len (else_rets ) == 1
208220 else ast .Tuple ([self ._clone (r ) for r in else_rets ], ctx = ast .Load ())
209221 )
222+ n_returned_values = len (then_rets ) if len (then_rets ) > 1 else 0
210223
211224 # build local funcs
212225 then_def = ast .FunctionDef (
@@ -258,7 +271,7 @@ def _rewrite_if(
258271 ],
259272 keywords = [],
260273 )
261- return then_def , else_def , call , drop
274+ return then_def , else_def , call , drop , n_returned_values
262275
263276 def _filter_target (self , node , tgt_mapping ):
264277 """
@@ -330,17 +343,32 @@ def visit_If(self, node):
330343 # the targets we need to export
331344 tgt , tgt_mapping = self ._make_targets (node , then_assigns , else_assigns )
332345
333- then_def , else_def , call , dropped = self ._rewrite_if (
346+ then_def , else_def , call , dropped , n_returned_values = self ._rewrite_if (
334347 node ,
335348 then_assigns ,
336349 else_assigns ,
337350 tgt_mapping = tgt_mapping ,
338351 known_local_variables = known_local_variables ,
339352 )
340353 if dropped and isinstance (tgt , ast .Tuple ):
341- tgt = ast .Tuple (
342- tuple (t for t in tgt .elts if t .id not in dropped ), ctx = ast .Store ()
354+ tgt_elts = tuple (t for t in tgt .elts if t .id not in dropped )
355+ elif isinstance (tgt , ast .Tuple ):
356+ tgt_elts = tuple (t for t in tgt .elts if t .id not in dropped )
357+ else :
358+ tgt_elts = [tgt ]
359+
360+ if n_returned_values == 0 :
361+ assert len (tgt_elts ) == 1 , (
362+ f"Inconsistencies between n_returned_values={ n_returned_values } , "
363+ f"dropped={ dropped } , tgt.elts={ tgt .elts } , tgt_elts={ tgt_elts } "
343364 )
365+ tgt = tgt_elts [0 ]
366+ else :
367+ assert n_returned_values == len (tgt_elts ), (
368+ f"Inconsistencies between n_returned_values={ n_returned_values } , "
369+ f"dropped={ dropped } , tgt.elts={ tgt .elts } , tgt_elts={ tgt_elts } "
370+ )
371+ tgt = ast .Tuple (tgt_elts , ctx = ast .Store ())
344372
345373 added = {tgt .id } if isinstance (tgt , ast .Name ) else set (t .id for t in tgt .elts )
346374 assign = ast .Assign (targets = [tgt ], value = call )
@@ -364,7 +392,7 @@ def visit_If(self, node):
364392 )
365393 then_expr = then_ret .value
366394 else_expr = else_ret .value
367- then_def , else_def , call , dropped = self ._rewrite_if (
395+ then_def , else_def , call , dropped , n_returned_values = self ._rewrite_if (
368396 node , [then_expr ], [else_expr ], known_local_variables = known_local_variables
369397 )
370398 ret = ast .Return (call )
@@ -809,7 +837,9 @@ def forward(self, x, y):
809837
810838@contextlib .contextmanager
811839def torch_export_rewrite (
812- rewrite : Optional [List [Union [Tuple [type , str ], Callable ]]] = None , verbose : int = 0
840+ rewrite : Optional [List [Union [Tuple [type , str ], Callable ]]] = None ,
841+ dump_rewriting : Optional [str ] = None ,
842+ verbose : int = 0 ,
813843):
814844 """
815845 Automatically rewrite the methods given in `rewrite` to export
@@ -818,6 +848,7 @@ def torch_export_rewrite(
818848 :param rewrite: methods of functions to rewrite, if not empty, the function may try
819849 to discover them, a method is defined by its class (a type) and its name
820850 if the class is local, by itself otherwise
851+ :param dump_rewriting: dumps rewriting information in file beginning with that prefix
821852 :param verbose: verbosity, up to 10, 10 shows the rewritten code,
822853 ``verbose=1`` shows the rewritten function,
823854 ``verbose=2`` shows the rewritten code as well
@@ -890,6 +921,7 @@ def forward(self, x, y):
890921 f"__globals__={ sorted (me .__globals__ )} "
891922 )
892923 mod = sys .modules [module ]
924+ cls_name = module
893925 cls = mod
894926 name = name
895927 to_rewrite = me
@@ -916,7 +948,19 @@ def forward(self, x, y):
916948 if verbose :
917949 print (f"[torch_export_rewrite] rewrites { kind } { cls .__name__ } .{ name } " )
918950 keep [cls , name ] = to_rewrite
951+ if dump_rewriting :
952+ filename = f"{ dump_rewriting } .{ kind } .{ cls_name } .{ name } .original.py"
953+ if verbose :
954+ print (f"[torch_export_rewrite] dump original code in { filename !r} " )
955+ with open (filename , "w" ) as f :
956+ f .write (inspect .getsource (to_rewrite ))
919957 rewr = transform_method (to_rewrite , verbose = max (verbose - 1 , 0 ))
958+ if dump_rewriting :
959+ filename = f"{ dump_rewriting } .{ kind } .{ cls_name } .{ name } .rewritten.py"
960+ if verbose :
961+ print (f"[torch_export_rewrite] dump rewritten code in { filename !r} " )
962+ with open (filename , "w" ) as f :
963+ f .write (rewr .code )
920964 setattr (cls , name , rewr .func )
921965
922966 try :
0 commit comments