@@ -149,6 +149,7 @@ def __init__(
149
149
float_dtype : type_utils .TfdsDType | None = np .float32 ,
150
150
mapping : Mapping [str , epath .PathLike ] | None = None ,
151
151
overwrite_version : str | None = None ,
152
+ filters : Mapping [str , Any ] | None = None ,
152
153
** kwargs : Any ,
153
154
):
154
155
"""Initializes a CroissantBuilder.
@@ -170,6 +171,10 @@ def __init__(
170
171
it to `~/Downloads/document.csv`, you can specify
171
172
`mapping={"document.csv": "~/Downloads/document.csv"}`.
172
173
overwrite_version: Semantic version of the dataset to be set.
174
+ filters: A dict of filters to apply to the records at preparation time (in
175
+ the `_generate_examples` function). The keys should be field names and
176
+ the values should be the values to filter by. If a record matches all
177
+ the filters, it will be included in the dataset.
173
178
**kwargs: kwargs to pass to GeneratorBasedBuilder directly.
174
179
"""
175
180
if mapping is None :
@@ -201,6 +206,7 @@ def __init__(
201
206
202
207
self ._int_dtype = int_dtype
203
208
self ._float_dtype = float_dtype
209
+ self ._filters = filters or {}
204
210
205
211
super ().__init__ (
206
212
** kwargs ,
@@ -222,19 +228,11 @@ def _info(self) -> dataset_info.DatasetInfo:
222
228
disable_shuffling = self ._disable_shuffling ,
223
229
)
224
230
225
- def get_record_set (self , record_set_id : str ):
226
- """Returns the desired record set from self.metadata."""
227
- for record_set in self .dataset .metadata .record_sets :
228
- if huggingface_utils .convert_hf_name (record_set .id ) == record_set_id :
229
- return record_set
230
- raise ValueError (
231
- f'Did not find any record set with the name { record_set_id } .'
232
- )
233
-
234
231
def get_features (self ) -> Optional [feature_lib .FeatureConnector ]:
235
232
"""Infers the features dict for the required record set."""
236
- record_set = self .get_record_set (self .builder_config .name )
237
-
233
+ record_set = croissant_utils .get_record_set (
234
+ self .builder_config .name , metadata = self .metadata
235
+ )
238
236
fields = record_set .fields
239
237
features = {}
240
238
for field in fields :
@@ -249,18 +247,53 @@ def get_features(self) -> Optional[feature_lib.FeatureConnector]:
249
247
def _split_generators (
250
248
self , dl_manager : download .DownloadManager
251
249
) -> Dict [splits_lib .Split , split_builder_lib .SplitGenerator ]:
252
- # This will be updated when partitions are implemented in Croissant, ref to:
253
- # https://docs.google.com/document/d/1saz3usja6mk5ugJXNF64_uSXsOzIgbIV28_bu1QamVY
254
- return {'default' : self ._generate_examples ()} # pylint: disable=unreachable
250
+ # If a split recordset is joined for the required record set, we generate
251
+ # splits accordingly. Otherwise, it generates a single `default` split with
252
+ # all the records.
253
+ record_set = croissant_utils .get_record_set (
254
+ self .builder_config .name , metadata = self .metadata
255
+ )
256
+ if split_reference := croissant_utils .get_split_recordset (
257
+ record_set , metadata = self .metadata
258
+ ):
259
+ return {
260
+ split ['name' ]: self ._generate_examples (
261
+ filters = {
262
+ ** self ._filters ,
263
+ split_reference .reference_field .id : split ['name' ].encode (),
264
+ }
265
+ )
266
+ for split in split_reference .split_record_set .data
267
+ }
268
+ else :
269
+ return {'default' : self ._generate_examples (filters = self ._filters )}
255
270
256
271
def _generate_examples (
257
272
self ,
273
+ filters : dict [str , Any ],
258
274
) -> split_builder_lib .SplitGenerator :
259
- record_set = self .get_record_set (self .builder_config .name )
275
+ """Generates the examples for the given record set.
276
+
277
+ Args:
278
+ filters: A dict of filters to apply to the records. The keys should be
279
+ field names and the values should be the values to filter by. If a
280
+ record matches all the filters, it will be included in the dataset.
281
+
282
+ Yields:
283
+ A tuple of (index, record) for each record in the dataset.
284
+ """
285
+ record_set = croissant_utils .get_record_set (
286
+ self .builder_config .name , metadata = self .metadata
287
+ )
260
288
records = self .dataset .records (record_set .id )
261
289
for i , record in enumerate (records ):
262
290
# Some samples might not be TFDS-compatible as-is, e.g. from croissant
263
291
# describing HuggingFace datasets, so we convert them here. This shouldn't
264
292
# impact datasets which are already TFDS-compatible.
265
293
record = huggingface_utils .convert_hf_value (record , self .info .features )
266
- yield i , record
294
+ # After partition implementation, the filters will be applied from
295
+ # mlcroissant `dataset.records` directly.
296
+ # `records = records.filter(f == v for f, v in filters.items())``
297
+ # For now, we apply them in TFDS.
298
+ if all (record [filter ] == value for filter , value in filters .items ()):
299
+ yield i , record
0 commit comments