@@ -240,7 +240,10 @@ def generic_visit(self, node):
240240class RewrittenMethod :
241241 """
242242 Stores a rewritten method using
243- :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method>`.
243+ :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.
244+
245+ :param tree: ast tree
246+ :param func: callable compiled from the tree
244247 """
245248
246249 def __init__ (self , tree , func ):
@@ -253,20 +256,25 @@ def code(self) -> str:
253256 return ast .unparse (self .tree )
254257
255258 def __repr__ (self ):
259+ "usual"
256260 return f"{ self .__class__ .__name__ } ({ self .func } )"
257261
258262
259- def transform_method (func : Callable , wrapper_name = "torch_cond" ) -> RewrittenMethod :
263+ def transform_method (func : Callable , if_name = "torch_cond" ) -> RewrittenMethod :
260264 """
261265 Returns a new function based on `func` where every test (if)
262266 is replaced by a call to :func:`torch.cond`.
267+
268+ :param func: method or function to rewrite
269+ :param if_name: function calling the test
270+ :return: rewritten method
263271 """
264272 # Retrieve source of the function
265273 src = inspect .getsource (func )
266274 # Parse into AST
267275 tree = ast .parse (textwrap .dedent (src ))
268276 # Apply transformation
269- transformer = RewriteControlFlow (wrapper_name )
277+ transformer = RewriteControlFlow (if_name )
270278 new_tree = transformer .visit (tree )
271279 ast .fix_missing_locations (new_tree )
272280 _settl (new_tree , 0 )
@@ -286,9 +294,9 @@ def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMeth
286294 f"{ ast .unparse (new_tree )} \n --TREE--\n "
287295 f"{ ast .dump (new_tree , ** kws )} "
288296 ) from e
289- namespace : Dict [type , type ] = {}
297+ namespace : Dict [str , type ] = {}
290298 globs = func .__globals__ .copy ()
291- globs ["torch_cond" ] = torch .cond
299+ globs [if_name ] = torch .cond
292300 exec (mod , globs , namespace )
293301 new_func = namespace .get (func .__name__ )
294302 if not isinstance (new_func , types .FunctionType ):
0 commit comments