33import types
44import textwrap
55from typing import Callable , Dict , List , Set , Optional
6- import torch
76
87NODE_TYPES = tuple (
98 getattr (ast , k )
@@ -34,22 +33,20 @@ def _settl(node, lineno, level=0):
3433
3534class RewriteControlFlow (ast .NodeTransformer ):
3635 """
37- The class rewrites tests with function ``torch_cond`` :func:`torch.cond`.
36+ The class rewrites tests with function :func:`torch.cond`.
3837 ``empty_tensor`` is a function returning an empty tensor,
3938 when a branch returns something the other branch does not.
4039 """
4140
4241 def __init__ (
4342 self ,
44- wrapper_name ,
45- empty_tensor : str = "make_empty_tensor" ,
43+ prefix : str = "branch_cond" ,
4644 skip_objects : Optional [Dict [str , object ]] = None ,
4745 args_names : Optional [Set [str ]] = None ,
4846 ):
49- self .wrapper_name = wrapper_name
50- self .empty_tensor = empty_tensor
5147 self .counter = 0
5248 self .current_func_args = None
49+ self .prefix = prefix
5350 self .skip_objects = skip_objects or {}
5451 self .args_names = args_names or set ()
5552 self .local_variables = self .args_names .copy ()
@@ -83,7 +80,7 @@ def _find_id(self, exprs: List["ast.Node"]) -> List[str]:
8380 for n in ast .walk (expr ):
8481 if (
8582 isinstance (n , ast .Name )
86- and isinstance (n .ctx , ast .Store )
83+ and isinstance (n .ctx , ast .Load )
8784 and n .id not in self .skip_objects
8885 ):
8986 vars .append (n .id )
@@ -97,8 +94,8 @@ def _rewrite_if(
9794 drop = set ()
9895
9996 # extract free variables
100- then_name = f"{ self .wrapper_name } _then_{ self .counter } "
101- else_name = f"{ self .wrapper_name } _else_{ self .counter } "
97+ then_name = f"{ self .prefix } _then_{ self .counter } "
98+ else_name = f"{ self .prefix } _else_{ self .counter } "
10299 then_vars = self ._find_id (then_exprs )
103100 else_vars = self ._find_id (else_exprs )
104101 then_else_vars = set (_ for _ in [* then_vars , * else_vars ] if _ in known_local_variables )
@@ -173,8 +170,11 @@ def _rewrite_if(
173170 [ast .Name (id = v , ctx = ast .Load ()) for v in then_else_vars ],
174171 ctx = ast .Load (),
175172 )
173+
176174 call = ast .Call (
177- func = ast .Name (id = self .wrapper_name , ctx = ast .Load ()),
175+ func = ast .Attribute (
176+ value = ast .Name (id = "torch" , ctx = ast .Load ()), attr = "cond" , ctx = ast .Load ()
177+ ),
178178 args = [
179179 test_node ,
180180 ast .Name (id = then_name , ctx = ast .Load ()),
@@ -346,8 +346,7 @@ def inplace_add_parent(tree: "ast.Node"):
346346
347347def transform_method (
348348 func : Callable ,
349- if_name : str = "torch_cond" ,
350- empty_tensor : str = "make_empty_tensor" ,
349+ prefix : str = "branch_cond" ,
351350 verbose : int = 0 ,
352351) -> RewrittenMethod :
353352 """
@@ -367,8 +366,7 @@ def transform_method(
367366 See method ``_filter_target``.
368367
369368 :param func: method or function to rewrite
370- :param if_name: function calling the test
371- :param empty_tensor: function creating an empty tensor
369+ :param prefix: prefix used to create the functions for the branches
372370 :param verbose: verbosity
373371 :return: rewritten method
374372
@@ -390,7 +388,7 @@ def forward(self, x, y):
390388 x, y = torch.rand((3, 4)), torch.rand((3, 4))
391389 expected = Model()(x, y)
392390
393- rewritten = transform_method(Model.forward, verbose=10 )
391+ rewritten = transform_method(Model.forward)
394392 print("-- code --")
395393 print(rewritten.code)
396394
@@ -423,7 +421,7 @@ def forward(self, x, y):
423421 x, y = torch.rand((3, 4)), torch.rand((3, 4))
424422 expected = Model()(x, y)
425423
426- rewritten = transform_method(Model.forward, verbose=10 )
424+ rewritten = transform_method(Model.forward)
427425 print("-- code --")
428426 print(rewritten.code)
429427
@@ -447,8 +445,7 @@ def forward(self, x, y):
447445 print (f"[transform_method] -- tree --\n \n { ast .dump (tree , indent = 2 )} " )
448446 # Apply transformation
449447 transformer = RewriteControlFlow (
450- if_name ,
451- empty_tensor = empty_tensor ,
448+ prefix = prefix ,
452449 skip_objects = modules ,
453450 args_names = set (sig .parameters ),
454451 )
@@ -482,8 +479,6 @@ def forward(self, x, y):
482479 ) from e
483480 namespace : Dict [str , type ] = {}
484481 globs = func .__globals__ .copy ()
485- globs [if_name ] = torch .cond
486- globs [empty_tensor ] = lambda : torch .tensor ([])
487482 exec (mod , globs , namespace )
488483 new_func = namespace .get (func .__name__ )
489484 if not isinstance (new_func , types .FunctionType ):
0 commit comments