Skip to content

Commit 096cb87

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
remove allow-untyped-defs from torch/_prims/executor.py (pytorch#144233)
Pull Request resolved: pytorch#144233 Approved by: https://github.com/Skylion007
1 parent 0aa74d0 commit 096cb87

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

torch/_prims/executor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1-
# mypy: allow-untyped-defs
2-
from typing import Callable, Optional
1+
from typing import Any, Callable, Optional, TypeVar
2+
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
33

44
from torch._prims.context import TorchRefsMode
55
from torch.fx import GraphModule
66
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
77

88

9+
T = TypeVar("T")
10+
P = ParamSpec("P")
11+
Ts = TypeVarTuple("Ts")
12+
13+
914
def execute(
1015
gm: GraphModule,
11-
*args,
16+
*args: Unpack[Ts],
1217
executor: str = "aten",
1318
executor_parameters: Optional[dict] = None,
14-
):
19+
) -> Any:
1520
"""
1621
Prototype ATen executor.
1722
@@ -25,7 +30,7 @@ def execute(
2530
raise ValueError(msg)
2631

2732

28-
def make_traced(fn: Callable):
33+
def make_traced(fn: Callable[P, T]) -> Callable[P, T]:
2934
"""
3035
Returns a function that, when called, will
3136
trace its torch operations to prims and then
@@ -49,12 +54,14 @@ def foo(a, b):
4954
result = traced_foo(a, b, executor='aten')
5055
"""
5156

52-
def _traced(*args, executor="aten", **kwargs):
57+
def _traced(*args: P.args, **kwargs: P.kwargs) -> T:
58+
executor = str(kwargs.pop("executor", "aten"))
59+
5360
# TODO: caching
5461
wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
5562

5663
with TorchRefsMode():
5764
gm = make_fx(wrapped)(all_args)
5865
return execute(gm, all_args, executor=executor)
5966

60-
return _traced
67+
return _traced # type: ignore[return-value]

0 commit comments

Comments
 (0)