Skip to content

Commit 027772d

Browse files
committed
mypy
1 parent 4a20331 commit 027772d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
import types
44
import textwrap
5-
from typing import Callable
5+
from typing import Callable, Dict
66
import torch
77

88
NODE_TYPES = tuple(
@@ -240,7 +240,7 @@ def generic_visit(self, node):
240240
class RewrittenMethod:
241241
"""
242242
Stores a rewritten method using
243-
:func:`onnx_diagnostic.torch_export_patches.path_module.transform_method>`.
243+
:func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method>`.
244244
"""
245245

246246
def __init__(self, tree, func):
@@ -286,7 +286,7 @@ def transform_method(func: Callable, wrapper_name="torch_cond") -> RewrittenMeth
286286
f"{ast.unparse(new_tree)}\n--TREE--\n"
287287
f"{ast.dump(new_tree, **kws)}"
288288
) from e
289-
namespace = {}
289+
namespace: Dict[type, type] = {}
290290
globs = func.__globals__.copy()
291291
globs["torch_cond"] = torch.cond
292292
exec(mod, globs, namespace)

0 commit comments

Comments
 (0)