@@ -68,18 +68,34 @@ class BaseDataSource(MappingView, Sequence):
68
68
split: The split to load in the data source.
69
69
decoders: Optional decoders for decoding.
70
70
data_source: The underlying data source to initialize in the __post_init__.
71
+ deserialize_method: How to deserialize the bytes that are read before
72
+ returning.
71
73
"""
72
74
73
75
dataset_info : dataset_info_lib .DatasetInfo
74
76
split : splits_lib .Split | None = None
75
77
decoders : type_utils .TreeDict [decode .partial_decode .DecoderArg ] | None = None
76
78
data_source : DataSource [Any ] = dataclasses .field (init = False )
79
+ deserialize_method : decode .DeserializeMethod = (
80
+ decode .DeserializeMethod .DESERIALIZE_AND_DECODE
81
+ )
82
+
83
+ def _deserialize (self , record : Any ) -> Any :
84
+ match self .deserialize_method :
85
+ case decode .DeserializeMethod .RAW_BYTES :
86
+ return record
87
+ case decode .DeserializeMethod .DESERIALIZE_NO_DECODE :
88
+ if file_format := self .dataset_info .file_format :
89
+ return file_format .deserialize (record )
90
+ raise ValueError ('No file format set, cannot deserialize bytes!' )
91
+ case decode .DeserializeMethod .DESERIALIZE_AND_DECODE :
92
+ if features := self .dataset_info .features :
93
+ return features .deserialize_example_np (record , decoders = self .decoders ) # pylint: disable=attribute-error
94
+ raise ValueError ('No features set, cannot decode example!' )
77
95
78
96
def __getitem__ (self , key : SupportsIndex ) -> Any :
79
97
record = self .data_source [key .__index__ ()]
80
- return self .dataset_info .features .deserialize_example_np (
81
- record , decoders = self .decoders
82
- )
98
+ return self ._deserialize (record )
83
99
84
100
def __getitems__ (self , keys : Sequence [int ]) -> Sequence [Any ]:
85
101
"""Retrieves items by batch.
@@ -98,17 +114,12 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
98
114
if not keys :
99
115
return []
100
116
records = self .data_source .__getitems__ (keys )
101
- features = self .dataset_info .features
102
117
if len (keys ) != len (records ):
103
118
raise IndexError (
104
- f'Requested { len (keys )} records but got'
105
- f' { len (records )} records.'
119
+ f'Requested { len (keys )} records but got { len (records )} records.'
106
120
f'{ keys = } , { records = } '
107
121
)
108
- return [
109
- features .deserialize_example_np (record , decoders = self .decoders )
110
- for record in records
111
- ]
122
+ return [self ._deserialize (record ) for record in records ]
112
123
113
124
def __repr__ (self ) -> str :
114
125
decoders_repr = (
0 commit comments