Skip to content

Commit 547d013

Browse files
Marvin182 and The TensorFlow Datasets Authors
authored and committed
Allow any SupportsIndex for BaseDataSource.__getitem__() for better compatibility with Grain.
PiperOrigin-RevId: 627995734
1 parent 3ec275c commit 547d013

File tree

2 files changed

+14
-23
lines changed

2 files changed

+14
-23
lines changed

tensorflow_datasets/core/data_sources/base.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
from collections.abc import MappingView, Sequence
1919
import dataclasses
2020
import typing
21-
from typing import Any, Generic, Iterable, Protocol, TypeVar
21+
from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar
2222

23-
from absl import logging
2423
from tensorflow_datasets.core import dataset_info as dataset_info_lib
2524
from tensorflow_datasets.core import decode
2625
from tensorflow_datasets.core import splits as splits_lib
@@ -38,7 +37,7 @@ class DataSource(Protocol, Generic[T]):
3837
def __len__(self) -> int:
3938
"""Returns the total number of records in the data source."""
4039

41-
def __getitem__(self, key: int) -> T:
40+
def __getitem__(self, key: SupportsIndex) -> T:
4241
"""Returns the value for the given `key`."""
4342

4443
def __getitems__(self, keys: Iterable[int]) -> T:
@@ -76,43 +75,35 @@ class BaseDataSource(MappingView, Sequence):
7675
decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
7776
data_source: DataSource[Any] = dataclasses.field(init=False)
7877

79-
def __getitem__(self, record_key: int) -> Any:
80-
if isinstance(record_key, Sequence):
81-
logging.error(
82-
'Calling DataSource.__getitem__() with sequence '
83-
'of record keys (%s) is deprecated. Either pass a single '
84-
'integer or switch to __getitems__().',
85-
record_key,
86-
)
87-
return self.__getitems__(record_key)
88-
record = self.data_source[record_key]
78+
def __getitem__(self, key: SupportsIndex) -> Any:
79+
record = self.data_source[key.__index__()]
8980
return self.dataset_info.features.deserialize_example_np(
9081
record, decoders=self.decoders
9182
)
9283

93-
def __getitems__(self, record_keys: Sequence[int]) -> Sequence[Any]:
84+
def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
9485
"""Retrieves items by batch.
9586
9687
This method allows PyTorch to load records by batch, rather than one by one.
9788
9889
Args:
99-
record_keys: a sequence of keys.
90+
keys: a sequence of keys.
10091
10192
Returns:
10293
The records associated with the keys.
10394
10495
Raises:
10596
IndexError: If the number of retrieved records is incorrect.
10697
"""
107-
if not record_keys:
98+
if not keys:
10899
return []
109-
records = self.data_source.__getitems__(record_keys)
100+
records = self.data_source.__getitems__(keys)
110101
features = self.dataset_info.features
111-
if len(record_keys) != len(records):
102+
if len(keys) != len(records):
112103
raise IndexError(
113-
f'Requested {len(record_keys)} records but got'
104+
f'Requested {len(keys)} records but got'
114105
f' {len(records)} records.'
115-
f'{record_keys=}, {records=}'
106+
f'{keys=}, {records=}'
116107
)
117108
return [
118109
features.deserialize_example_np(record, decoders=self.decoders)

tensorflow_datasets/core/data_sources/base_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def test_read_write(
8686
assert len(data_source) == 3
8787
for i in range(3):
8888
assert data_source[i] == {'id': i}
89-
assert data_source[[0, 2]] == [{'id': 0}, {'id': 2}]
90-
assert data_source[range(0, 2)] == [{'id': 0}, {'id': 1}]
91-
assert data_source[[]] == [] # pylint: disable=g-explicit-bool-comparison
89+
assert data_source.__getitems__([0, 2]) == [{'id': 0}, {'id': 2}]
90+
assert data_source.__getitems__(range(0, 2)) == [{'id': 0}, {'id': 1}]
91+
assert data_source.__getitems__([]) == [] # pylint: disable=g-explicit-bool-comparison
9292
for i, element in enumerate(data_source):
9393
assert element == {'id': i}
9494

0 commit comments

Comments (0)