1- # mypy: allow-untyped-defs
2- from typing import Callable , Optional
1+ from typing import Any , Callable , Optional , TypeVar
2+ from typing_extensions import ParamSpec , TypeVarTuple , Unpack
33
44from torch ._prims .context import TorchRefsMode
55from torch .fx import GraphModule
66from torch .fx .experimental .proxy_tensor import make_fx , wrapper_and_args_for_make_fx
77
88
9+ T = TypeVar ("T" )
10+ P = ParamSpec ("P" )
11+ Ts = TypeVarTuple ("Ts" )
12+
13+
914def execute (
1015 gm : GraphModule ,
11- * args ,
16+ * args : Unpack [ Ts ] ,
1217 executor : str = "aten" ,
1318 executor_parameters : Optional [dict ] = None ,
14- ):
19+ ) -> Any :
1520 """
1621 Prototype ATen executor.
1722
@@ -25,7 +30,7 @@ def execute(
2530 raise ValueError (msg )
2631
2732
28- def make_traced (fn : Callable ) :
33+ def make_traced (fn : Callable [ P , T ]) -> Callable [ P , T ] :
2934 """
3035 Returns a function that, when called, will
3136 trace its torch operations to prims and then
@@ -49,12 +54,14 @@ def foo(a, b):
4954 result = traced_foo(a, b, executor='aten')
5055 """
5156
52- def _traced (* args , executor = "aten" , ** kwargs ):
57+ def _traced (* args : P .args , ** kwargs : P .kwargs ) -> T :
58+ executor = str (kwargs .pop ("executor" , "aten" ))
59+
5360 # TODO: caching
5461 wrapped , all_args = wrapper_and_args_for_make_fx (fn , args , kwargs )
5562
5663 with TorchRefsMode ():
5764 gm = make_fx (wrapped )(all_args )
5865 return execute (gm , all_args , executor = executor )
5966
60- return _traced
67+ return _traced # type: ignore[return-value]
0 commit comments