
Commit 401b149

aorenste authored and pytorchmergebot committed
[BE] typing for decorators - distributed/_tensor/ops/utils (pytorch#142139)
Test Plan: unit tests

Differential Revision: D62302679

Pull Request resolved: pytorch#142139
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
1 parent 159b7ad commit 401b149

9 files changed: +25 -13 lines changed


torch/distributed/tensor/_ops/_conv_ops.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor
 from typing import List
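Note: # mypy: allow-untyped-decorators is a per-module directive that silences mypy's --disallow-untyped-decorators check, which reports every function wrapped by an unannotated decorator as untyped. With the registration decorators in utils.py now fully annotated (last file below), the escape hatch can be dropped here and in the sibling _ops modules. A minimal sketch of the check these modules are opting back into; the file and function names are illustrative, not from this diff:

    # demo.py -- check with: mypy --disallow-untyped-decorators demo.py
    from typing import Callable, TypeVar
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")
    _T = TypeVar("_T")

    def untyped(func):  # unannotated decorator
        return func

    def typed(func: Callable[_P, _T]) -> Callable[_P, _T]:  # annotated
        return func

    @untyped  # error: Untyped decorator makes function "f" untyped
    def f(x: int) -> int:
        return x

    @typed  # accepted: "g" keeps its signature (int) -> int
    def g(x: int) -> int:
        return x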

torch/distributed/tensor/_ops/_embedding_ops.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor

torch/distributed/tensor/_ops/_experimental_ops.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor

torch/distributed/tensor/_ops/_math_ops.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import math

torch/distributed/tensor/_ops/_matrix_ops.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor

torch/distributed/tensor/_ops/_random_ops.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import torch
 from torch.distributed.device_mesh import DeviceMesh

torch/distributed/tensor/_ops/_tensor_ops.py

Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from typing import cast, List, Optional, Sequence, Sized, Tuple

@@ -593,7 +592,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding:
             args_schema=(
                 schema_suggestion.args_schema[0],
                 dim,
-                schema_suggestion.args_schema[1][dim],
+                schema_suggestion.args_schema[1][dim],  # type: ignore[index]
             ),
             kwargs_schema=op_schema.kwargs_schema,
         )
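The type: ignore[index] here accompanies the typing work: schema_suggestion.args_schema[1] has a broad static type (an op argument can be many things), so mypy does not consider it indexable even though the runtime value at this call site is always a sequence. A minimal sketch of this error class, with illustrative types rather than the actual OpSchema ones:

    from typing import Sequence, Union

    def pick(args: Sequence[Union[int, Sequence[int]]], dim: int) -> int:
        # mypy: Value of type "Union[int, Sequence[int]]" is not indexable  [index]
        # The runtime invariant guarantees args[1] is a sequence, so the error
        # is silenced with a targeted ignore rather than a restructure.
        return args[1][dim]  # type: ignore[index]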

torch/distributed/tensor/_ops/_view_ops.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from dataclasses import dataclass

torch/distributed/tensor/_ops/utils.py

Lines changed: 24 additions & 4 deletions

@@ -3,7 +3,18 @@
 import functools
 import itertools
 import operator
-from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Callable,
+    cast,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+)
+from typing_extensions import ParamSpec

 import torch
 from torch.distributed.tensor._api import DTensor

@@ -25,14 +36,21 @@
 )


+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+
 # convenient wrapper to register sharding propagation rules
 # pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
-def register_prop_rule(op, schema_info=None):
+def register_prop_rule(
+    op: Union[torch._ops.OpOverload, List[torch._ops.OpOverload]],
+    schema_info: Optional[RuntimeSchemaInfo] = None,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
     # pyre-fixme[53]: Captured variable `func` is not annotated.
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
-    def wrapper(impl):
+    def wrapper(impl: Callable[_P, _T]) -> Callable[_P, _T]:
         overloads = op if isinstance(op, list) else [op]
         for overload in overloads:
             DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(

@@ -43,7 +61,9 @@ def wrapper(impl):
     return wrapper


-def register_op_strategy(op, schema_info=None):
+def register_op_strategy(
+    op, schema_info=None
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
     # pyre-fixme[53]: Captured variable `func` is not annotated.
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
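This is the core of the change: pairing ParamSpec (_P) with a TypeVar (_T) lets each decorator factory declare that it returns the wrapped function with its signature intact, so mypy no longer treats @register_prop_rule(...) as an untyped decorator. A self-contained sketch of the same pattern; the registry and op names are illustrative stand-ins, not the DTensor internals:

    from typing import Callable, Dict, List, TypeVar, Union
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")
    _T = TypeVar("_T")

    _RULES: Dict[str, Callable[..., object]] = {}  # stand-in for the rule registry

    def register_rule(
        op: Union[str, List[str]],
    ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
        # The factory returns a decorator that hands impl back unchanged, so
        # the decorated function keeps its exact signature under type checking.
        def wrapper(impl: Callable[_P, _T]) -> Callable[_P, _T]:
            for overload in op if isinstance(op, list) else [op]:
                _RULES[overload] = impl
            return impl
        return wrapper

    @register_rule(["aten.add", "aten.add_"])
    def add_rule(schema: str) -> str:
        return schema

    # Under mypy, reveal_type(add_rule) is "def (schema: str) -> str", not
    # "Any" as it would be with an unannotated decorator factory.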
