Skip to content

Commit 20ee5f9

Browse files
bobrenjc93 authored and pytorchmergebot committed
remove allow-untyped-defs from elastic_distributed_sampler.py (pytorch#154620)
Pull Request resolved: pytorch#154620 Approved by: https://github.com/Skylion007
1 parent 9c06dff commit 20ee5f9

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

torch/distributed/elastic/utils/data/elastic_distributed_sampler.py

Lines changed: 30 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# mypy: allow-untyped-defs
32

43
# Copyright (c) Facebook, Inc. and its affiliates.
54
# All rights reserved.
@@ -8,12 +7,20 @@
87
# LICENSE file in the root directory of this source tree.
98

109
import math
10+
from collections.abc import Iterator, Sized
11+
from typing import cast, Optional, TypeVar
1112

1213
import torch
14+
from torch.utils.data import Dataset
1315
from torch.utils.data.distributed import DistributedSampler
1416

1517

16-
class ElasticDistributedSampler(DistributedSampler):
18+
T = TypeVar("T")
19+
20+
__all__ = ["ElasticDistributedSampler"]
21+
22+
23+
class ElasticDistributedSampler(DistributedSampler[T]):
1724
"""
1825
Sampler that restricts data loading to a subset of
1926
the dataset for elastic training.
@@ -34,25 +41,39 @@ class ElasticDistributedSampler(DistributedSampler):
3441
start_index (optional): Which index of the dataset to start sampling from
3542
"""
3643

37-
def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
44+
def __init__(
45+
self,
46+
dataset: Dataset[T],
47+
num_replicas: Optional[int] = None,
48+
rank: Optional[int] = None,
49+
start_index: int = 0,
50+
):
3851
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
39-
if start_index >= len(dataset):
52+
if not isinstance(dataset, Sized):
53+
raise TypeError("Dataset must be an instance of collections.abc.Sized")
54+
55+
# Cast to Sized for mypy
56+
sized_dataset = cast(Sized, dataset)
57+
58+
if start_index >= len(sized_dataset):
4059
raise ValueError(
41-
f"Start index {start_index} should be less than dataset size {len(dataset)}"
60+
f"Start index {start_index} should be less than dataset size {len(sized_dataset)}"
4261
)
4362

4463
self.start_index = start_index
64+
sized_dataset = cast(Sized, self.dataset)
4565
self.num_samples = int(
46-
math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas) # type: ignore[arg-type]
66+
math.ceil(float(len(sized_dataset) - self.start_index) / self.num_replicas)
4767
)
4868
self.total_size = self.num_samples * self.num_replicas
4969

50-
def __iter__(self):
70+
def __iter__(self) -> Iterator[T]:
5171
# deterministically shuffle based on epoch
5272
g = torch.Generator()
5373
g.manual_seed(self.epoch)
74+
sized_dataset = cast(Sized, self.dataset)
5475
indices = (
55-
torch.randperm(len(self.dataset) - self.start_index, generator=g) # type: ignore[arg-type]
76+
torch.randperm(len(sized_dataset) - self.start_index, generator=g)
5677
.add(self.start_index)
5778
.tolist()
5879
)
@@ -67,5 +88,5 @@ def __iter__(self):
6788

6889
return iter(indices)
6990

70-
def __len__(self):
91+
def __len__(self) -> int:
7192
return self.num_samples

0 commit comments

Comments
 (0)