Skip to content

Commit dbd3d62

Browse files
committed
mypy
1 parent 93a8445 commit dbd3d62

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def __init__(
100100
prefix: str = "branch_cond",
101101
skip_objects: Optional[Dict[str, object]] = None,
102102
args_names: Optional[Set[str]] = None,
103-
filter_node: Optional[Callable["ast.Node", bool]] = None,
104-
pre_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
105-
post_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
103+
filter_node: Optional[Callable[["ast.Node"], bool]] = None,
104+
pre_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
105+
post_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
106106
):
107107
self.counter_test = 0
108108
self.counter_loop = 0
@@ -717,9 +717,9 @@ def transform_method(
717717
func: Callable,
718718
prefix: str = "branch_cond",
719719
verbose: int = 0,
720-
filter_node: Optional[Callable["ast.Node", bool]] = None,
721-
pre_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
722-
post_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
720+
filter_node: Optional[Callable[["ast.Node"], bool]] = None,
721+
pre_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
722+
post_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
723723
) -> RewrittenMethod:
724724
"""
725725
Returns a new function based on `func` where every test (if)
@@ -941,7 +941,7 @@ def forward(self, x, y):
941941
cls, name = me
942942
to_rewrite = getattr(cls, name)
943943
kind = "method"
944-
kws = {}
944+
kws = {} # type: ignore[var-annotated]
945945
else:
946946
if isinstance(me, dict):
947947
assert "function" in me and (

0 commit comments

Comments
 (0)