Skip to content

Commit 420977d

Browse files
committed
add
1 parent 027772d commit 420977d

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,10 @@ def generic_visit(self, node):
240240
class RewrittenMethod:
241241
"""
242242
Stores a rewritten method using
243-
:func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method>`.
243+
:func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.
244+
245+
:param tree: ast tree
246+
:param func: callable compiled from the tree
244247
"""
245248

246249
def __init__(self, tree, func):
@@ -253,20 +256,25 @@ def code(self) -> str:
253256
return ast.unparse(self.tree)
254257

255258
def __repr__(self):
259+
"usual"
256260
return f"{self.__class__.__name__}({self.func})"
257261

258262

259-
def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMethod:
263+
def transform_method(func: Callable, if_name="torch_cond") -> RewrittenMethod:
260264
"""
261265
Returns a new function based on `func` where every test (if)
262266
is replaced by a call to :func:`torch.cond`.
267+
268+
:param func: method or function to rewrite
269+
:param if_name: function calling the test
270+
:return: rewritten method
263271
"""
264272
# Retrieve source of the function
265273
src = inspect.getsource(func)
266274
# Parse into AST
267275
tree = ast.parse(textwrap.dedent(src))
268276
# Apply transformation
269-
transformer = RewriteControlFlow(wrapper_name)
277+
transformer = RewriteControlFlow(if_name)
270278
new_tree = transformer.visit(tree)
271279
ast.fix_missing_locations(new_tree)
272280
_settl(new_tree, 0)
@@ -286,9 +294,9 @@ def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMeth
286294
f"{ast.unparse(new_tree)}\n--TREE--\n"
287295
f"{ast.dump(new_tree, **kws)}"
288296
) from e
289-
namespace: Dict[type, type] = {}
297+
namespace: Dict[str, type] = {}
290298
globs = func.__globals__.copy()
291-
globs["torch_cond"] = torch.cond
299+
globs[if_name] = torch.cond
292300
exec(mod, globs, namespace)
293301
new_func = namespace.get(func.__name__)
294302
if not isinstance(new_func, types.FunctionType):

0 commit comments

Comments
 (0)