11#!/usr/bin/env python3
2- # mypy: allow-untyped-defs
32
43# Copyright (c) Facebook, Inc. and its affiliates.
54# All rights reserved.
87# LICENSE file in the root directory of this source tree.
98
109import math
10+ from collections .abc import Iterator , Sized
11+ from typing import cast , Optional , TypeVar
1112
1213import torch
14+ from torch .utils .data import Dataset
1315from torch .utils .data .distributed import DistributedSampler
1416
1517
16- class ElasticDistributedSampler (DistributedSampler ):
18+ T = TypeVar ("T" )
19+
20+ __all__ = ["ElasticDistributedSampler" ]
21+
22+
23+ class ElasticDistributedSampler (DistributedSampler [T ]):
1724 """
1825 Sampler that restricts data loading to a subset of
1926 the dataset for elastic training.
@@ -34,25 +41,39 @@ class ElasticDistributedSampler(DistributedSampler):
3441 start_index (optional): Which index of the dataset to start sampling from
3542 """
3643
37- def __init__ (self , dataset , num_replicas = None , rank = None , start_index = 0 ):
44+ def __init__ (
45+ self ,
46+ dataset : Dataset [T ],
47+ num_replicas : Optional [int ] = None ,
48+ rank : Optional [int ] = None ,
49+ start_index : int = 0 ,
50+ ):
3851 super ().__init__ (dataset = dataset , num_replicas = num_replicas , rank = rank )
39- if start_index >= len (dataset ):
52+ if not isinstance (dataset , Sized ):
53+ raise TypeError ("Dataset must be an instance of collections.abc.Sized" )
54+
55+ # Cast to Sized for mypy
56+ sized_dataset = cast (Sized , dataset )
57+
58+ if start_index >= len (sized_dataset ):
4059 raise ValueError (
41- f"Start index { start_index } should be less than dataset size { len (dataset )} "
60+ f"Start index { start_index } should be less than dataset size { len (sized_dataset )} "
4261 )
4362
4463 self .start_index = start_index
64+ sized_dataset = cast (Sized , self .dataset )
4565 self .num_samples = int (
46- math .ceil (float (len (self . dataset ) - self .start_index ) / self .num_replicas ) # type: ignore[arg-type]
66+ math .ceil (float (len (sized_dataset ) - self .start_index ) / self .num_replicas )
4767 )
4868 self .total_size = self .num_samples * self .num_replicas
4969
50- def __iter__ (self ):
70+ def __iter__ (self ) -> Iterator [ T ] :
5171 # deterministically shuffle based on epoch
5272 g = torch .Generator ()
5373 g .manual_seed (self .epoch )
74+ sized_dataset = cast (Sized , self .dataset )
5475 indices = (
55- torch .randperm (len (self . dataset ) - self .start_index , generator = g ) # type: ignore[arg-type]
76+ torch .randperm (len (sized_dataset ) - self .start_index , generator = g )
5677 .add (self .start_index )
5778 .tolist ()
5879 )
@@ -67,5 +88,5 @@ def __iter__(self):
6788
6889 return iter (indices )
6990
70- def __len__ (self ):
91+ def __len__ (self ) -> int :
7192 return self .num_samples
0 commit comments