@@ -89,13 +89,20 @@ class RewriteControlFlow(ast.NodeTransformer):
8989 :param skip_objects: to skip variable names if included in that list
9090 such as modules
9191 :param args_names: defines the local variables
92+ :param filter_nodes: a function which is used to decide which node
93+ to rewrite, True by default
94+ :param pre_rewriter: a rewriter applied before the automated rewriting
95+ :param post_rewriter: a rewriter applied after the automated rewriting
9296 """
9397
9498 def __init__ (
9599 self ,
96100 prefix : str = "branch_cond" ,
97101 skip_objects : Optional [Dict [str , object ]] = None ,
98102 args_names : Optional [Set [str ]] = None ,
103+ filter_node : Optional [Callable ["ast.Node" , bool ]] = None ,
104+ pre_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
105+ post_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
99106 ):
100107 self .counter_test = 0
101108 self .counter_loop = 0
@@ -104,6 +111,9 @@ def __init__(
104111 self .skip_objects = skip_objects or {}
105112 self .args_names = args_names or set ()
106113 self .local_variables = self .args_names .copy ()
114+ self .filter_node = filter_node or (lambda _node : True )
115+ self .pre_rewriter = pre_rewriter or (lambda node : node )
116+ self .post_rewriter = post_rewriter or (lambda node : node )
107117
108118 def generic_visit (self , node ):
109119 return super ().generic_visit (node )
@@ -320,6 +330,11 @@ def _make_targets(self, node, then_assigns, else_assigns):
320330 return tgt , tgt_mapping
321331
322332 def visit_If (self , node ):
333+ if not self .filter_node (node ):
334+ return [node ]
335+
336+ node = self .pre_rewriter (node )
337+
323338 # First recurse into subnodes
324339 known_local_variables = self .local_variables .copy ()
325340 node = self .generic_visit (node )
@@ -380,7 +395,7 @@ def visit_If(self, node):
380395 ast .copy_location (assign , node )
381396 ast .fix_missing_locations (assign )
382397 self .local_variables = known_local_variables | added
383- return [then_def , else_def , assign ]
398+ return [self . post_rewriter ( n ) for n in [ then_def , else_def , assign ] ]
384399
385400 # Case 2: return in both branches, we assume both branches return the same results.
386401 then_ret = node .body [- 1 ]
@@ -403,7 +418,7 @@ def visit_If(self, node):
403418 ret = ast .Return (call )
404419 ast .copy_location (ret , node )
405420 ast .fix_missing_locations (ret )
406- return [then_def , else_def , ret ]
421+ return [self . post_rewriter ( n ) for n in [ then_def , else_def , ret ] ]
407422
408423 def _find_loop_vars (self , node ):
409424 assert isinstance (node , ast .For ), f"Unexpected type { type (node )} for node"
@@ -462,6 +477,11 @@ def _find_loop_vars(self, node):
462477 )
463478
464479 def visit_For (self , node ):
480+ if not self .filter_node (node ):
481+ return [node ]
482+
483+ node = self .pre_rewriter (node )
484+
465485 # For nested loops.
466486 self .generic_visit (node )
467487 # look for variables, loop, inputs and outputs of the body
@@ -622,7 +642,7 @@ def visit_For(self, node):
622642 ctx = ast .Store (),
623643 )
624644 assign = ast .Assign (targets = [target ], value = call )
625- return [func_def , assign ]
645+ return [self . post_rewriter ( func_def ), self . post_rewriter ( assign ) ]
626646
627647
628648class RewrittenMethod :
@@ -697,6 +717,9 @@ def transform_method(
697717 func : Callable ,
698718 prefix : str = "branch_cond" ,
699719 verbose : int = 0 ,
720+ filter_node : Optional [Callable ["ast.Node" , bool ]] = None ,
721+ pre_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
722+ post_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
700723) -> RewrittenMethod :
701724 """
702725 Returns a new function based on `func` where every test (if)
@@ -717,6 +740,9 @@ def transform_method(
717740 :param func: method or function to rewrite
718741 :param prefix: prefix used to create the functions for the branches
719742 :param verbose: verbosity
743+ :param filter_node: a function which tells which node to rewrite
744+ :param pre_rewriter: a rewriter applied before the automated rewriting
745+ :param post_rewriter: a rewriter applied after the automated rewriting
720746 :return: rewritten method
721747
722748 An example with **return**:
@@ -801,6 +827,9 @@ def forward(self, x, y):
801827 prefix = prefix ,
802828 skip_objects = modules ,
803829 args_names = set (sig .parameters ),
830+ filter_node = filter_node ,
831+ pre_rewriter = pre_rewriter ,
832+ post_rewriter = post_rewriter ,
804833 )
805834 normalize_assignment_in_test (tree )
806835 inplace_add_parent (tree )
@@ -912,7 +941,22 @@ def forward(self, x, y):
912941 cls , name = me
913942 to_rewrite = getattr (cls , name )
914943 kind = "method"
944+ kws = {}
915945 else :
946+ if isinstance (me , dict ):
947+ assert "function" in me and (
948+ "filter_node" in me or "pre_rewriter" in me or "post_rewriter" in me
949+ ), (
950+ f"If the rewriting code is defined as a dictionary, key "
951+ f"'function' must be defined, other arguments must be understood by "
952+ f"{ transform_method .__name__ } , "
953+ f"the given value is { me !r} ."
954+ )
955+ kws = me
956+ me = me ["function" ]
957+ del kws ["function" ]
958+ else :
959+ kws = {}
916960 name = me .__qualname__
917961 spl = name .split ("." )
918962 if len (spl ) == 1 :
@@ -958,8 +1002,9 @@ def forward(self, x, y):
9581002 if verbose :
9591003 print (f"[torch_export_rewrite] dump original code in { filename !r} " )
9601004 with open (filename , "w" ) as f :
961- f .write (inspect .getsource (to_rewrite ))
962- rewr = transform_method (to_rewrite , verbose = max (verbose - 1 , 0 ))
1005+ code = inspect .getsource (to_rewrite )
1006+ f .write (code )
1007+ rewr = transform_method (to_rewrite , verbose = max (verbose - 1 , 0 ), ** kws )
9631008 if dump_rewriting :
9641009 filename = f"{ dump_rewriting } .{ kind } .{ cls_name } .{ name } .rewritten.py"
9651010 if verbose :
0 commit comments