@@ -46,10 +46,83 @@ def visit_FunctionDef(self, node):
4646 self .current_func_args = old_args
4747 return node
4848
49+ def _rewrite_if (self , node , then_expr , else_expr ):
50+ test_node = node .test
51+
52+ # extract free variables
53+ then_name = f"{ self .wrapper_name } _then_{ self .counter } "
54+ else_name = f"{ self .wrapper_name } _else_{ self .counter } "
55+ then_vars = sorted (
56+ {
57+ n .id
58+ for n in ast .walk (then_expr )
59+ if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
60+ }
61+ )
62+ else_vars = sorted (
63+ {
64+ n .id
65+ for n in ast .walk (else_expr )
66+ if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
67+ }
68+ )
69+
70+ then_else_vars = set (_ for _ in [* then_vars , * else_vars ] if _ != "torch" )
71+
72+ # build local funcs
73+ then_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
74+ then_def = ast .FunctionDef (
75+ name = then_name ,
76+ args = ast .arguments (
77+ posonlyargs = [],
78+ args = then_args ,
79+ kwonlyargs = [],
80+ kw_defaults = [],
81+ defaults = [],
82+ ),
83+ body = [ast .Return (then_expr )],
84+ decorator_list = [],
85+ returns = None ,
86+ )
87+ else_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
88+ else_def = ast .FunctionDef (
89+ name = else_name ,
90+ args = ast .arguments (
91+ posonlyargs = [],
92+ args = else_args ,
93+ kwonlyargs = [],
94+ kw_defaults = [],
95+ defaults = [],
96+ ),
97+ body = [ast .Return (else_expr )],
98+ decorator_list = [],
99+ returns = None ,
100+ )
101+ # fix locations
102+ for n in (then_def , else_def ):
103+ ast .copy_location (n , node )
104+ ast .fix_missing_locations (n )
105+ assert hasattr (n , "lineno" )
106+ # wrapper call and assignment
107+ then_else_args_list = ast .List (
108+ [ast .Name (id = v , ctx = ast .Load ()) for v in then_else_vars ],
109+ ctx = ast .Load (),
110+ )
111+ call = ast .Call (
112+ func = ast .Name (id = self .wrapper_name , ctx = ast .Load ()),
113+ args = [
114+ test_node ,
115+ ast .Name (id = then_name , ctx = ast .Load ()),
116+ ast .Name (id = else_name , ctx = ast .Load ()),
117+ then_else_args_list ,
118+ ],
119+ keywords = [],
120+ )
121+ return then_def , else_def , call
122+
49123 def visit_If (self , node ):
50124 # First recurse into subnodes
51125 node = self .generic_visit (node )
52- test_node = node .test
53126
54127 # Case 1: simple assignment in both branches
55128 if (
@@ -68,79 +141,9 @@ def visit_If(self, node):
68141 and tgt .id == else_assign .targets [0 ].id
69142 ):
70143 self .counter += 1
71- then_name = f"{ self .wrapper_name } _then_{ self .counter } "
72- else_name = f"{ self .wrapper_name } _else_{ self .counter } "
73144 then_expr = then_assign .value
74145 else_expr = else_assign .value
75- # extract free variables
76- then_vars = sorted (
77- {
78- n .id
79- for n in ast .walk (then_expr )
80- if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
81- }
82- )
83- else_vars = sorted (
84- {
85- n .id
86- for n in ast .walk (else_expr )
87- if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
88- }
89- )
90- # build local funcs
91- then_args = [ast .arg (arg = v , annotation = None ) for v in then_vars ]
92- then_def = ast .FunctionDef (
93- name = then_name ,
94- args = ast .arguments (
95- posonlyargs = [],
96- args = then_args ,
97- kwonlyargs = [],
98- kw_defaults = [],
99- defaults = [],
100- ),
101- body = [ast .Return (then_expr )],
102- decorator_list = [],
103- returns = None ,
104- )
105- else_args = [ast .arg (arg = v , annotation = None ) for v in else_vars ]
106- else_def = ast .FunctionDef (
107- name = else_name ,
108- args = ast .arguments (
109- posonlyargs = [],
110- args = else_args ,
111- kwonlyargs = [],
112- kw_defaults = [],
113- defaults = [],
114- ),
115- body = [ast .Return (else_expr )],
116- decorator_list = [],
117- returns = None ,
118- )
119- # fix locations
120- for n in (then_def , else_def ):
121- ast .copy_location (n , node )
122- ast .fix_missing_locations (n )
123- assert hasattr (n , "lineno" )
124- # wrapper call and assignment
125- then_args_tuple = ast .Tuple (
126- [ast .Name (id = v , ctx = ast .Load ()) for v in then_vars ],
127- ctx = ast .Load (),
128- )
129- else_args_tuple = ast .Tuple (
130- [ast .Name (id = v , ctx = ast .Load ()) for v in else_vars ],
131- ctx = ast .Load (),
132- )
133- call = ast .Call (
134- func = ast .Name (id = self .wrapper_name , ctx = ast .Load ()),
135- args = [
136- test_node ,
137- ast .Name (id = then_name , ctx = ast .Load ()),
138- ast .Name (id = else_name , ctx = ast .Load ()),
139- then_args_tuple ,
140- else_args_tuple ,
141- ],
142- keywords = [],
143- )
146+ then_def , else_def , call = self ._rewrite_if (node , then_expr , else_expr )
144147 assign = ast .Assign (targets = [tgt ], value = call )
145148 ast .copy_location (assign , node )
146149 ast .fix_missing_locations (assign )
@@ -159,74 +162,7 @@ def visit_If(self, node):
159162 then_expr = then_ret .value
160163 else_expr = else_ret .value
161164 self .counter += 1
162- then_name = f"{ self .wrapper_name } _then_{ self .counter } "
163- else_name = f"{ self .wrapper_name } _else_{ self .counter } "
164- # extract free variables
165- then_vars = sorted (
166- {
167- n .id
168- for n in ast .walk (then_expr )
169- if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
170- }
171- )
172- else_vars = sorted (
173- {
174- n .id
175- for n in ast .walk (else_expr )
176- if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
177- }
178- )
179-
180- then_else_vars = set (_ for _ in [* then_vars , * else_vars ] if _ != "torch" )
181-
182- # build local funcs
183- then_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
184- then_def = ast .FunctionDef (
185- name = then_name ,
186- args = ast .arguments (
187- posonlyargs = [],
188- args = then_args ,
189- kwonlyargs = [],
190- kw_defaults = [],
191- defaults = [],
192- ),
193- body = [ast .Return (then_expr )],
194- decorator_list = [],
195- returns = None ,
196- )
197- else_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
198- else_def = ast .FunctionDef (
199- name = else_name ,
200- args = ast .arguments (
201- posonlyargs = [],
202- args = else_args ,
203- kwonlyargs = [],
204- kw_defaults = [],
205- defaults = [],
206- ),
207- body = [ast .Return (else_expr )],
208- decorator_list = [],
209- returns = None ,
210- )
211- for n in (then_def , else_def ):
212- ast .copy_location (n , node )
213- ast .fix_missing_locations (n )
214- # wrapper call and return
215- then_else_args_list = ast .List (
216- [ast .Name (id = v , ctx = ast .Load ()) for v in then_else_vars ],
217- ctx = ast .Load (),
218- )
219-
220- call = ast .Call (
221- func = ast .Name (id = self .wrapper_name , ctx = ast .Load ()),
222- args = [
223- test_node ,
224- ast .Name (id = then_name , ctx = ast .Load ()),
225- ast .Name (id = else_name , ctx = ast .Load ()),
226- then_else_args_list ,
227- ],
228- keywords = [],
229- )
165+ then_def , else_def , call = self ._rewrite_if (node , then_expr , else_expr )
230166 ret = ast .Return (call )
231167 ast .copy_location (ret , node )
232168 ast .fix_missing_locations (ret )
@@ -260,24 +196,41 @@ def __repr__(self):
260196 return f"{ self .__class__ .__name__ } ({ self .func } )"
261197
262198
263- def transform_method (func : Callable , if_name = "torch_cond" ) -> RewrittenMethod :
199+ def transform_method (
200+ func : Callable , if_name = "torch_cond" , verbose : int = 0
201+ ) -> RewrittenMethod :
264202 """
265203 Returns a new function based on `func` where every test (if)
266204 is replaced by a call to :func:`torch.cond`.
267205
268206 :param func: method or function to rewrite
269207 :param if_name: function calling the test
208+ :param verbose: verbosity
270209 :return: rewritten method
271210 """
272211 # Retrieve source of the function
273212 src = inspect .getsource (func )
213+ if verbose :
214+ print (f"[transform_method] -- source -- { func } " )
215+ print (src )
274216 # Parse into AST
275217 tree = ast .parse (textwrap .dedent (src ))
218+ if verbose > 1 :
219+ print ("[transform_method] -- tree --" )
220+ print (ast .dump (tree , indent = 2 ))
276221 # Apply transformation
277222 transformer = RewriteControlFlow (if_name )
278223 new_tree = transformer .visit (tree )
224+ if verbose > 1 :
225+ print ("[transform_method] -- new tree --" )
226+ print (ast .dump (tree , indent = 2 ))
279227 ast .fix_missing_locations (new_tree )
280228 _settl (new_tree , 0 )
229+
230+ if verbose > 0 :
231+ print ("[transform_method] -- new code --" )
232+ code = ast .unparse (new_tree )
233+ print (code )
281234 try :
282235 mod = compile (new_tree , filename = "<ast>" , mode = "exec" )
283236 except TypeError as e :
0 commit comments