18
18
from collections .abc import MappingView , Sequence
19
19
import dataclasses
20
20
import typing
21
- from typing import Any , Generic , Iterable , Protocol , TypeVar
21
+ from typing import Any , Generic , Iterable , Protocol , SupportsIndex , TypeVar
22
22
23
- from absl import logging
24
23
from tensorflow_datasets .core import dataset_info as dataset_info_lib
25
24
from tensorflow_datasets .core import decode
26
25
from tensorflow_datasets .core import splits as splits_lib
@@ -38,7 +37,7 @@ class DataSource(Protocol, Generic[T]):
38
37
def __len__ (self ) -> int :
39
38
"""Returns the total number of records in the data source."""
40
39
41
- def __getitem__ (self , key : int ) -> T :
40
+ def __getitem__ (self , key : SupportsIndex ) -> T :
42
41
"""Returns the value for the given `key`."""
43
42
44
43
def __getitems__ (self , keys : Iterable [int ]) -> T :
@@ -76,43 +75,35 @@ class BaseDataSource(MappingView, Sequence):
76
75
decoders : type_utils .TreeDict [decode .partial_decode .DecoderArg ] | None = None
77
76
data_source : DataSource [Any ] = dataclasses .field (init = False )
78
77
79
- def __getitem__ (self , record_key : int ) -> Any :
80
- if isinstance (record_key , Sequence ):
81
- logging .error (
82
- 'Calling DataSource.__getitem__() with sequence '
83
- 'of record keys (%s) is deprecated. Either pass a single '
84
- 'integer or switch to __getitems__().' ,
85
- record_key ,
86
- )
87
- return self .__getitems__ (record_key )
88
- record = self .data_source [record_key ]
78
+ def __getitem__ (self , key : SupportsIndex ) -> Any :
79
+ record = self .data_source [key .__index__ ()]
89
80
return self .dataset_info .features .deserialize_example_np (
90
81
record , decoders = self .decoders
91
82
)
92
83
93
- def __getitems__ (self , record_keys : Sequence [int ]) -> Sequence [Any ]:
84
+ def __getitems__ (self , keys : Sequence [int ]) -> Sequence [Any ]:
94
85
"""Retrieves items by batch.
95
86
96
87
This method allows PyTorch to load records by batch, rather than one by one.
97
88
98
89
Args:
99
- record_keys : a sequence of keys.
90
+ keys : a sequence of keys.
100
91
101
92
Returns:
102
93
The records associated with the keys.
103
94
104
95
Raises:
105
96
IndexError: If the number of retrieved records is incorrect.
106
97
"""
107
- if not record_keys :
98
+ if not keys :
108
99
return []
109
- records = self .data_source .__getitems__ (record_keys )
100
+ records = self .data_source .__getitems__ (keys )
110
101
features = self .dataset_info .features
111
- if len (record_keys ) != len (records ):
102
+ if len (keys ) != len (records ):
112
103
raise IndexError (
113
- f'Requested { len (record_keys )} records but got'
104
+ f'Requested { len (keys )} records but got'
114
105
f' { len (records )} records.'
115
- f'{ record_keys = } , { records = } '
106
+ f'{ keys = } , { records = } '
116
107
)
117
108
return [
118
109
features .deserialize_example_np (record , decoders = self .decoders )
0 commit comments