@@ -100,9 +100,9 @@ def __init__(
100100 prefix : str = "branch_cond" ,
101101 skip_objects : Optional [Dict [str , object ]] = None ,
102102 args_names : Optional [Set [str ]] = None ,
103- filter_node : Optional [Callable ["ast.Node" , bool ]] = None ,
104- pre_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
105- post_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
103+ filter_node : Optional [Callable [[ "ast.Node" ] , bool ]] = None ,
104+ pre_rewriter : Optional [Callable [[ "ast.Node" ] , "ast.Node" ]] = None ,
105+ post_rewriter : Optional [Callable [[ "ast.Node" ] , "ast.Node" ]] = None ,
106106 ):
107107 self .counter_test = 0
108108 self .counter_loop = 0
@@ -717,9 +717,9 @@ def transform_method(
717717 func : Callable ,
718718 prefix : str = "branch_cond" ,
719719 verbose : int = 0 ,
720- filter_node : Optional [Callable ["ast.Node" , bool ]] = None ,
721- pre_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
722- post_rewriter : Optional [Callable ["ast.Node" , "ast.Node" ]] = None ,
720+ filter_node : Optional [Callable [[ "ast.Node" ] , bool ]] = None ,
721+ pre_rewriter : Optional [Callable [[ "ast.Node" ] , "ast.Node" ]] = None ,
722+ post_rewriter : Optional [Callable [[ "ast.Node" ] , "ast.Node" ]] = None ,
723723) -> RewrittenMethod :
724724 """
725725 Returns a new function based on `func` where every test (if)
@@ -941,7 +941,7 @@ def forward(self, x, y):
941941 cls , name = me
942942 to_rewrite = getattr (cls , name )
943943 kind = "method"
944- kws = {}
944+ kws = {} # type: ignore[var-annotated]
945945 else :
946946 if isinstance (me , dict ):
947947 assert "function" in me and (
0 commit comments