@@ -3,7 +3,18 @@
 import functools
 import itertools
 import operator
-from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Callable,
+    cast,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+)
+from typing_extensions import ParamSpec
 
 import torch
 from torch.distributed.tensor._api import DTensor
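The `ParamSpec`/`TypeVar` pair imported above is what lets a decorator declare that it returns a function with the same signature it received. A minimal, self-contained sketch of that pattern (illustrative only, not code from this commit):

from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_T = TypeVar("_T")
_P = ParamSpec("_P")

def passthrough(func: Callable[_P, _T]) -> Callable[_P, _T]:
    # The annotations tell type checkers that the decorated function keeps
    # its exact parameter list (_P) and return type (_T), instead of
    # collapsing to Callable[..., Any].
    return func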
@@ -25,14 +36,21 @@
 )
 
 
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+
 # convenient wrapper to register sharding propagation rules
 # pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
-def register_prop_rule(op, schema_info=None):
+def register_prop_rule(
+    op: Union[torch._ops.OpOverload, List[torch._ops.OpOverload]],
+    schema_info: Optional[RuntimeSchemaInfo] = None,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
     # pyre-fixme[53]: Captured variable `func` is not annotated.
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
-    def wrapper(impl):
+    def wrapper(impl: Callable[_P, _T]) -> Callable[_P, _T]:
         overloads = op if isinstance(op, list) else [op]
         for overload in overloads:
             DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(
@@ -43,7 +61,9 @@ def wrapper(impl):
     return wrapper
 
 
-def register_op_strategy(op, schema_info=None):
+def register_op_strategy(
+    op, schema_info=None
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
     # pyre-fixme[53]: Captured variable `func` is not annotated.
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
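With these annotations in place, a function passed through `register_prop_rule` keeps its own signature for static type checkers. A hedged usage sketch; the op overload and rule body below are placeholders, not taken from this commit:

# Assumes the imports and definitions of this module; aten.clone.default
# and the rule body are illustrative placeholders only.
@register_prop_rule(torch.ops.aten.clone.default)
def clone_rule(op_schema):
    # A real rule would compute and return an output sharding for the op;
    # elided here.
    ...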