22import inspect
33import types
44import textwrap
5+ from typing import Callable
6+ import torch
7+
8+ NODE_TYPES = tuple (
9+ getattr (ast , k )
10+ for k in dir (ast )
11+ if "A" <= k [0 ] <= "Z" and isinstance (getattr (ast , k ), type )
12+ )
13+
14+
15+ def _settl (node , lineno , level = 0 ):
16+ if isinstance (node , (str , int , float )):
17+ return node
18+ if isinstance (node , list ):
19+ for n in node :
20+ _settl (n , lineno , level = level + 1 )
21+ return node
22+ if isinstance (node , NODE_TYPES ):
23+ if not hasattr (node , "lineno" ) or node .lineno is None :
24+ node .lineno = lineno
25+ for k in dir (node ):
26+ if k in {"s" , "n" }:
27+ continue
28+ if k [0 ] == "_" :
29+ continue
30+ v = getattr (node , k )
31+ _settl (v , max (lineno , node .lineno ), level = level + 1 )
32+ return node
533
634
735class RewriteControlFlow (ast .NodeTransformer ):
@@ -22,6 +50,7 @@ def visit_If(self, node):
2250 # First recurse into subnodes
2351 node = self .generic_visit (node )
2452 test_node = node .test
53+
2554 # Case 1: simple assignment in both branches
2655 if (
2756 len (node .body ) == 1
@@ -91,12 +120,15 @@ def visit_If(self, node):
91120 for n in (then_def , else_def ):
92121 ast .copy_location (n , node )
93122 ast .fix_missing_locations (n )
123+ assert hasattr (n , "lineno" )
94124 # wrapper call and assignment
95125 then_args_tuple = ast .Tuple (
96- [ast .Name (id = v , ctx = ast .Load ()) for v in then_vars ], ctx = ast .Load ()
126+ [ast .Name (id = v , ctx = ast .Load ()) for v in then_vars ],
127+ ctx = ast .Load (),
97128 )
98129 else_args_tuple = ast .Tuple (
99- [ast .Name (id = v , ctx = ast .Load ()) for v in else_vars ], ctx = ast .Load ()
130+ [ast .Name (id = v , ctx = ast .Load ()) for v in else_vars ],
131+ ctx = ast .Load (),
100132 )
101133 call = ast .Call (
102134 func = ast .Name (id = self .wrapper_name , ctx = ast .Load ()),
@@ -113,6 +145,7 @@ def visit_If(self, node):
113145 ast .copy_location (assign , node )
114146 ast .fix_missing_locations (assign )
115147 return [then_def , else_def , assign ]
148+
116149 # Case 2: simple return in both branches
117150 if (
118151 len (node .body ) == 1
@@ -143,22 +176,33 @@ def visit_If(self, node):
143176 if isinstance (n , ast .Name ) and isinstance (n .ctx , ast .Load )
144177 }
145178 )
179+
180+ then_else_vars = set (_ for _ in [* then_vars , * else_vars ] if _ != "torch" )
181+
146182 # build local funcs
147- then_args = [ast .arg (arg = v , annotation = None ) for v in then_vars ]
183+ then_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
148184 then_def = ast .FunctionDef (
149185 name = then_name ,
150186 args = ast .arguments (
151- posonlyargs = [], args = then_args , kwonlyargs = [], kw_defaults = [], defaults = []
187+ posonlyargs = [],
188+ args = then_args ,
189+ kwonlyargs = [],
190+ kw_defaults = [],
191+ defaults = [],
152192 ),
153193 body = [ast .Return (then_expr )],
154194 decorator_list = [],
155195 returns = None ,
156196 )
157- else_args = [ast .arg (arg = v , annotation = None ) for v in else_vars ]
197+ else_args = [ast .arg (arg = v , annotation = None ) for v in then_else_vars ]
158198 else_def = ast .FunctionDef (
159199 name = else_name ,
160200 args = ast .arguments (
161- posonlyargs = [], args = else_args , kwonlyargs = [], kw_defaults = [], defaults = []
201+ posonlyargs = [],
202+ args = else_args ,
203+ kwonlyargs = [],
204+ kw_defaults = [],
205+ defaults = [],
162206 ),
163207 body = [ast .Return (else_expr )],
164208 decorator_list = [],
@@ -168,20 +212,18 @@ def visit_If(self, node):
168212 ast .copy_location (n , node )
169213 ast .fix_missing_locations (n )
170214 # wrapper call and return
171- then_args_tuple = ast .Tuple (
172- [ast .Name (id = v , ctx = ast .Load ()) for v in then_vars ], ctx = ast .Load ()
173- )
174- else_args_tuple = ast .Tuple (
175- [ast .Name (id = v , ctx = ast .Load ()) for v in else_vars ], ctx = ast .Load ()
215+ then_else_args_list = ast .List (
216+ [ast .Name (id = v , ctx = ast .Load ()) for v in then_else_vars ],
217+ ctx = ast .Load (),
176218 )
219+
177220 call = ast .Call (
178221 func = ast .Name (id = self .wrapper_name , ctx = ast .Load ()),
179222 args = [
180223 test_node ,
181224 ast .Name (id = then_name , ctx = ast .Load ()),
182225 ast .Name (id = else_name , ctx = ast .Load ()),
183- then_args_tuple ,
184- else_args_tuple ,
226+ then_else_args_list ,
185227 ],
186228 keywords = [],
187229 )
@@ -195,24 +237,29 @@ def generic_visit(self, node):
195237 return super ().generic_visit (node )
196238
197239
198- def _fix_missing_locations_node (node ):
199- if not hasattr (node , "lineno" ):
200- node .lineno = 999
201- for chi in ast .iter_child_nodes (node ):
202- _fix_missing_locations_node (chi )
240+ class RewrittenMethod :
241+ """
242+ Stores a rewritten method using
243+ :func:`onnx_diagnostic.torch_export_patches.path_module.transform_method>`.
244+ """
245+
246+ def __init__ (self , tree , func ):
247+ self .tree = tree
248+ self .func = func
203249
250+ @property
251+ def code (self ) -> str :
252+ """Returns the source."""
253+ return ast .unparse (self .tree )
204254
205- def _fix_missing_locations (new_tree ):
206- for node in ast .walk (new_tree ):
207- _fix_missing_locations_node (node )
255+ def __repr__ (self ):
256+ return f"{ self .__class__ .__name__ } ({ self .func } )"
208257
209258
210- def transform_method (func , wrapper_name = "torch_cond" ):
259+ def transform_method (func : Callable , wrapper_name = "torch_cond" ) -> RewrittenMethod :
211260 """
212- Returns a new function based on `func` where every test (if, while, assert,
213- ternary, comparison, boolean op) is replaced by a call to `wrapper_name`.
214-
215- wrapper_name should refer to a function taking a single boolean argument.
261+ Returns a new function based on `func` where every test (if)
262+ is replaced by a call to :func:`torch.cond`.
216263 """
217264 # Retrieve source of the function
218265 src = inspect .getsource (func )
@@ -222,13 +269,28 @@ def transform_method(func, wrapper_name="torch_cond"):
222269 transformer = RewriteControlFlow (wrapper_name )
223270 new_tree = transformer .visit (tree )
224271 ast .fix_missing_locations (new_tree )
225-
226- # fix other location
227- _fix_missing_locations (new_tree )
228- mod = compile (new_tree , filename = "<ast>" , mode = "exec" )
272+ _settl (new_tree , 0 )
273+ try :
274+ mod = compile (new_tree , filename = "<ast>" , mode = "exec" )
275+ except TypeError as e :
276+ if 'required field "lineno" missing from stmt' in str (e ):
277+ # Could not find a way to avoid compilng a string.
278+ # The error message still pops up without indicating which node is not
279+ # properly set.
280+ code = ast .unparse (new_tree )
281+ mod = compile (code , filename = "<source>" , mode = "exec" )
282+ else :
283+ kws = dict (include_attributes = True , annotate_fields = True , indent = 4 )
284+ raise RuntimeError (
285+ f"Unable to compile code\n --CODE--\n "
286+ f"{ ast .unparse (new_tree )} \n --TREE--\n "
287+ f"{ ast .dump (new_tree , ** kws )} "
288+ ) from e
229289 namespace = {}
230- exec (mod , func .__globals__ , namespace )
290+ globs = func .__globals__ .copy ()
291+ globs ["torch_cond" ] = torch .cond
292+ exec (mod , globs , namespace )
231293 new_func = namespace .get (func .__name__ )
232294 if not isinstance (new_func , types .FunctionType ):
233295 raise RuntimeError ("Transformed function not found" )
234- return new_tree , new_func
296+ return RewrittenMethod ( new_tree , new_func )
0 commit comments