Commit 90e81a1

bobrenjc93 authored and pytorchmergebot committed
Migrate from Tuple -> tuple in torch/utils/data (pytorch#144255)
Pull Request resolved: pytorch#144255
Approved by: https://github.com/andrewkho
1 parent 8ccf3f6 commit 90e81a1
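
The change swaps the deprecated typing.Tuple alias for the builtin tuple in annotations, as allowed by PEP 585 on the Python versions PyTorch supports (3.9 and newer). A minimal sketch of the before/after annotation style, using only the standard library:

    # Before: the typing alias.
    from typing import Tuple

    def old_style(point: Tuple[int, int]) -> Tuple[str, ...]:
        return tuple(str(c) for c in point)

    # After: the builtin type is subscriptable at runtime on Python 3.9+ (PEP 585),
    # so no typing import is needed for plain tuples.
    def new_style(point: tuple[int, int]) -> tuple[str, ...]:
        return tuple(str(c) for c in point)

    print(new_style((1, 2)))  # ('1', '2')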

12 files changed: +45, -47 lines

torch/utils/data/_utils/collate.py

Lines changed: 9 additions & 9 deletions
@@ -12,7 +12,7 @@
import contextlib
import copy
import re
-from typing import Callable, Dict, Optional, Tuple, Type, Union
+from typing import Callable, Dict, Optional, Type, Union

import torch

@@ -118,7 +118,7 @@ def default_convert(data):
def collate(
    batch,
    *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
):
    r"""
    General collate function that handles collection type of element within each batch.
@@ -243,7 +243,7 @@ def collate(
def collate_tensor_fn(
    batch,
    *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
):
    elem = batch[0]
    out = None
@@ -275,7 +275,7 @@ def collate_tensor_fn(
def collate_numpy_array_fn(
    batch,
    *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
):
    elem = batch[0]
    # array of string classes and object
@@ -288,36 +288,36 @@ def collate_numpy_array_fn(
def collate_numpy_scalar_fn(
    batch,
    *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
):
    return torch.as_tensor(batch)


def collate_float_fn(
    batch,
    *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
):
    return torch.tensor(batch, dtype=torch.float64)


def collate_int_fn(
    batch,
    *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
):
    return torch.tensor(batch)


def collate_str_fn(
    batch,
    *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
):
    return batch


-default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
+default_collate_fn_map: Dict[Union[Type, tuple[Type, ...]], Callable] = {
    torch.Tensor: collate_tensor_fn
}
with contextlib.suppress(ImportError):

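For context on the collate_fn_map parameter retyped above: its keys may be a single type or a tuple of types, which is why the annotation is Dict[Union[Type, tuple[Type, ...]], Callable]. A minimal sketch of registering a custom entry; the BoundingBox class and its collate function are hypothetical and used only for illustration, while collate and default_collate_fn_map are the private helpers shown in this file:

    import copy

    import torch
    from torch.utils.data._utils.collate import collate, default_collate_fn_map

    class BoundingBox:  # hypothetical element type
        def __init__(self, coords):
            self.coords = coords

    def collate_bbox_fn(batch, *, collate_fn_map=None):
        # Stack the coordinates of every BoundingBox in the batch into one tensor.
        return torch.tensor([b.coords for b in batch], dtype=torch.float32)

    # Keys can be a type or a tuple of types, hence tuple[Type, ...] in the annotation.
    custom_map = copy.copy(default_collate_fn_map)
    custom_map[BoundingBox] = collate_bbox_fn

    batch = [BoundingBox([0, 0, 2, 2]), BoundingBox([1, 1, 3, 3])]
    print(collate(batch, collate_fn_map=custom_map))  # tensor of shape (2, 4)
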
torch/utils/data/datapipes/gen_pyi.py

Lines changed: 4 additions & 4 deletions
@@ -2,7 +2,7 @@
import os
import pathlib
from collections import defaultdict
-from typing import Any, Dict, List, Set, Tuple, Union
+from typing import Any, Dict, List, Set, Union


def materialize_lines(lines: List[str], indentation: int) -> str:
@@ -19,7 +19,7 @@ def gen_from_template(
    dir: str,
    template_name: str,
    output_name: str,
-    replacements: List[Tuple[str, Any, int]],
+    replacements: List[tuple[str, Any, int]],
):
    template_path = os.path.join(dir, template_name)
    output_path = os.path.join(dir, output_name)
@@ -75,7 +75,7 @@ def extract_class_name(line: str) -> str:

def parse_datapipe_file(
    file_path: str,
-) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
+) -> tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """Given a path to file, parses the file and returns a dictionary of method names to function signatures."""
    method_to_signature, method_to_class_name, special_output_type = {}, {}, set()
    doc_string_dict = defaultdict(list)
@@ -127,7 +127,7 @@ def parse_datapipe_file(

def parse_datapipe_files(
    file_paths: Set[str],
-) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
+) -> tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    (
        methods_and_signatures,
        methods_and_class_names,

torch/utils/data/datapipes/iter/combining.py

Lines changed: 4 additions & 5 deletions
@@ -12,7 +12,6 @@
    Literal,
    Optional,
    Sized,
-    Tuple,
    TypeVar,
)

@@ -54,7 +53,7 @@ class ConcaterIterDataPipe(IterDataPipe):
        [0, 1, 2, 0, 1, 2, 3, 4]
    """

-    datapipes: Tuple[IterDataPipe]
+    datapipes: tuple[IterDataPipe]

    def __init__(self, *datapipes: IterDataPipe):
        if len(datapipes) == 0:
@@ -668,7 +667,7 @@ def __del__(self):


@functional_datapipe("zip")
-class ZipperIterDataPipe(IterDataPipe[Tuple[_T_co]]):
+class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

@@ -685,7 +684,7 @@ class ZipperIterDataPipe(IterDataPipe[Tuple[_T_co]]):
        [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
    """

-    datapipes: Tuple[IterDataPipe]
+    datapipes: tuple[IterDataPipe]

    def __init__(self, *datapipes: IterDataPipe):
        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
@@ -695,7 +694,7 @@ def __init__(self, *datapipes: IterDataPipe):
        super().__init__()
        self.datapipes = datapipes  # type: ignore[assignment]

-    def __iter__(self) -> Iterator[Tuple[_T_co]]:
+    def __iter__(self) -> Iterator[tuple[_T_co]]:
        iterators = [iter(datapipe) for datapipe in self.datapipes]
        yield from zip(*iterators)

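The zip functional form touched above can be exercised as in this small sketch; the expected output is the one quoted in the class docstring:

    from torch.utils.data.datapipes.iter import IterableWrapper

    dp1 = IterableWrapper(range(5))
    dp2 = IterableWrapper(range(10, 15))
    dp3 = IterableWrapper(range(20, 25))
    print(list(dp1.zip(dp2, dp3)))
    # [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
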
torch/utils/data/datapipes/iter/fileopener.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from io import IOBase
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -13,7 +13,7 @@


@functional_datapipe("open_files")
-class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
+class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]):
    r"""
    Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``).

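FileOpenerIterDataPipe is usually reached through its FileOpener alias after a FileLister. A rough usage sketch; the ./data directory and the *.csv mask are assumptions made only for illustration:

    from torch.utils.data.datapipes.iter import FileLister, FileOpener

    files = FileLister(root="./data", masks="*.csv")  # yields pathnames
    pairs = FileOpener(files, mode="rt")              # yields (pathname, stream) tuples
    for pathname, stream in pairs:
        print(pathname, stream.read().splitlines()[0])  # first line of each file
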
torch/utils/data/datapipes/iter/routeddecoder.py

Lines changed: 5 additions & 5 deletions
@@ -1,5 +1,5 @@
from io import BufferedIOBase
-from typing import Any, Callable, Iterable, Iterator, Sized, Tuple
+from typing import Any, Callable, Iterable, Iterator, Sized

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -16,7 +16,7 @@


@functional_datapipe("routed_decode")
-class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]):
+class RoutedDecoderIterDataPipe(IterDataPipe[tuple[str, Any]]):
    r"""
    Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple.

@@ -38,12 +38,12 @@ class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]):

    def __init__(
        self,
-        datapipe: Iterable[Tuple[str, BufferedIOBase]],
+        datapipe: Iterable[tuple[str, BufferedIOBase]],
        *handlers: Callable,
        key_fn: Callable = extension_extract_fn,
    ) -> None:
        super().__init__()
-        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
+        self.datapipe: Iterable[tuple[str, BufferedIOBase]] = datapipe
        if not handlers:
            handlers = (decoder_basichandlers, decoder_imagehandler("torch"))
        self.decoder = Decoder(*handlers, key_fn=key_fn)
@@ -57,7 +57,7 @@ def __init__(

    def add_handler(self, *handler: Callable) -> None:
        self.decoder.add_handler(*handler)

-    def __iter__(self) -> Iterator[Tuple[str, Any]]:
+    def __iter__(self) -> Iterator[tuple[str, Any]]:
        for data in self.datapipe:
            pathname = data[0]
            result = self.decoder(data)

torch/utils/data/datapipes/iter/selecting.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
-from typing import Callable, Iterator, Tuple, TypeVar
+from typing import Callable, Iterator, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
@@ -78,7 +78,7 @@ def __iter__(self) -> Iterator[_T_co]:
        else:
            StreamWrapper.close_streams(data)

-    def _returnIfTrue(self, data: _T) -> Tuple[bool, _T]:
+    def _returnIfTrue(self, data: _T) -> tuple[bool, _T]:
        condition = self._apply_filter_fn(data)

        if df_wrapper.is_column(condition):

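This file implements the filter datapipe; the internal _returnIfTrue helper retyped above backs its dataframe handling. Plain usage, as a quick sketch:

    from torch.utils.data.datapipes.iter import IterableWrapper

    dp = IterableWrapper(range(10)).filter(lambda x: x % 2 == 0)
    print(list(dp))  # [0, 2, 4, 6, 8]
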
torch/utils/data/datapipes/iter/sharding.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from enum import IntEnum
-from typing import Dict, Sized, Tuple
+from typing import Dict, Sized

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -43,7 +43,7 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
    def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
        self.source_datapipe = source_datapipe
        self.sharding_group_filter = sharding_group_filter
-        self.groups: Dict[int, Tuple[int, int]] = {}
+        self.groups: Dict[int, tuple[int, int]] = {}
        self.num_of_instances = 1
        self.instance_id = 0
        self._update_num_of_instances()

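The groups attribute retyped above belongs to ShardingFilterIterDataPipe, which keeps only the elements that belong to the current shard. A minimal sketch of the usual two-step setup:

    from torch.utils.data.datapipes.iter import IterableWrapper

    dp = IterableWrapper(range(10)).sharding_filter()
    dp.apply_sharding(2, 0)  # 2 instances in total, this process is instance 0
    print(list(dp))          # [0, 2, 4, 6, 8]
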
torch/utils/data/datapipes/iter/streamreader.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
from io import IOBase
-from typing import Iterator, Optional, Tuple
+from typing import Iterator, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
@@ -9,7 +9,7 @@


@functional_datapipe("read_from_stream")
-class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
+class StreamReaderIterDataPipe(IterDataPipe[tuple[str, bytes]]):
    r"""
    Given IO streams and their label names, yield bytes with label name as tuple.

@@ -30,12 +30,12 @@ class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
    """

    def __init__(
-        self, datapipe: IterDataPipe[Tuple[str, IOBase]], chunk: Optional[int] = None
+        self, datapipe: IterDataPipe[tuple[str, IOBase]], chunk: Optional[int] = None
    ):
        self.datapipe = datapipe
        self.chunk = chunk

-    def __iter__(self) -> Iterator[Tuple[str, bytes]]:
+    def __iter__(self) -> Iterator[tuple[str, bytes]]:
        for furl, stream in self.datapipe:
            while True:
                d = stream.read(self.chunk)

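StreamReaderIterDataPipe pairs each chunk it reads with the stream's label. A short sketch mirroring the class docstring example, with StringIO standing in for a real binary stream:

    from io import StringIO

    from torch.utils.data.datapipes.iter import IterableWrapper, StreamReader

    dp = IterableWrapper([("alphabet", StringIO("abcde"))])
    print(list(StreamReader(dp, chunk=1)))
    # [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
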
torch/utils/data/datapipes/map/combining.py

Lines changed: 5 additions & 5 deletions
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
-from typing import Sized, Tuple, TypeVar
+from typing import Sized, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe
@@ -34,7 +34,7 @@ class ConcaterMapDataPipe(MapDataPipe):
        [0, 1, 2, 0, 1, 2]
    """

-    datapipes: Tuple[MapDataPipe]
+    datapipes: tuple[MapDataPipe]

    def __init__(self, *datapipes: MapDataPipe):
        if len(datapipes) == 0:
@@ -59,7 +59,7 @@ def __len__(self) -> int:


@functional_datapipe("zip")
-class ZipperMapDataPipe(MapDataPipe[Tuple[_T_co, ...]]):
+class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

@@ -78,7 +78,7 @@ class ZipperMapDataPipe(MapDataPipe[Tuple[_T_co, ...]]):
        [(0, 10), (1, 11), (2, 12)]
    """

-    datapipes: Tuple[MapDataPipe[_T_co], ...]
+    datapipes: tuple[MapDataPipe[_T_co], ...]

    def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None:
        if len(datapipes) == 0:
@@ -89,7 +89,7 @@ def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None:
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

-    def __getitem__(self, index) -> Tuple[_T_co, ...]:
+    def __getitem__(self, index) -> tuple[_T_co, ...]:
        res = []
        for dp in self.datapipes:
            try:

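ZipperMapDataPipe offers the same zip functional form for map-style pipes. A small sketch matching the docstring output:

    from torch.utils.data.datapipes.map import SequenceWrapper

    dp1 = SequenceWrapper(range(3))
    dp2 = SequenceWrapper(range(10, 13))
    print(list(dp1.zip(dp2)))  # [(0, 10), (1, 11), (2, 12)]
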
torch/utils/data/datapipes/utils/common.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
import os
import warnings
from io import IOBase
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from torch.utils._import_utils import dill_available

@@ -231,7 +231,7 @@ def get_file_binaries_from_pathnames(
    yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding))


-def validate_pathname_binary_tuple(data: Tuple[str, IOBase]):
+def validate_pathname_binary_tuple(data: tuple[str, IOBase]):
    if not isinstance(data, tuple):
        raise TypeError(
            f"pathname binary data should be tuple type, but it is type {type(data)}"

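validate_pathname_binary_tuple, retyped above, still checks against the builtin tuple at runtime, so the annotation change is behavior-neutral. A quick sketch of the path shown in this hunk:

    from io import BytesIO

    from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple

    validate_pathname_binary_tuple(("sample.bin", BytesIO(b"\x00\x01")))  # passes silently
    try:
        validate_pathname_binary_tuple(["sample.bin", BytesIO(b"\x00\x01")])  # a list, not a tuple
    except TypeError as e:
        print(e)  # pathname binary data should be tuple type, but it is type <class 'list'>
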