@@ -34,6 +34,8 @@
 ```
 """
 
+from __future__ import annotations
+
 from collections.abc import Mapping
 from typing import Any, Dict, Optional, Sequence
 
@@ -53,6 +55,7 @@
 from tensorflow_datasets.core.utils import croissant_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils import version as version_utils
+from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
 from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
 from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
 
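The new `apache_beam` import goes through `lazy_imports_utils`, which defers loading the heavy dependency until it is first used. A generic sketch of that idea, assuming nothing about TFDS's actual implementation (the `_LazyModule` name is illustrative):

```python
import importlib


class _LazyModule:
  """Illustrative stand-in for a lazily imported module."""

  def __init__(self, name: str):
    self._name = name
    self._module = None

  def __getattr__(self, attr: str):
    # The real import happens only on first attribute access.
    if self._module is None:
      self._module = importlib.import_module(self._name)
    return getattr(self._module, attr)


beam = _LazyModule('apache_beam')  # nothing is imported until beam.<attr> is touched
```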
@@ -142,7 +145,7 @@ class CroissantBuilder(
   def __init__(
       self,
       *,
-      jsonld: epath.PathLike,
+      jsonld: epath.PathLike | Mapping[str, Any],
       record_set_ids: Sequence[str] | None = None,
       disable_shuffling: bool | None = False,
       int_dtype: type_utils.TfdsDType | None = np.int64,
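With this change, `jsonld` accepts an already-parsed JSON-LD mapping in addition to a path. A minimal usage sketch, assuming the `CroissantBuilder` entry point defined in this file; the `croissant.json` filename and the `'default'` record set id are illustrative:

```python
import json

from tensorflow_datasets.core.dataset_builders import croissant_builder

# Hypothetical: a Croissant JSON-LD file that was already read and parsed.
with open('croissant.json') as f:
  metadata = json.load(f)  # a Mapping[str, Any] rather than an epath.PathLike

builder = croissant_builder.CroissantBuilder(
    jsonld=metadata,  # previously this argument had to be a path
    record_set_ids=['default'],
)
```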
@@ -245,7 +248,9 @@ def get_features(self) -> Optional[feature_lib.FeatureConnector]:
     return features_dict.FeaturesDict(features)
 
   def _split_generators(
-      self, dl_manager: download.DownloadManager
+      self,
+      dl_manager: download.DownloadManager,
+      pipeline: beam.Pipeline,
   ) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
     # If a split recordset is joined for the required record set, we generate
     # splits accordingly. Otherwise, it generates a single `default` split with
@@ -258,37 +263,56 @@ def _split_generators(
     ):
       return {
           split['name']: self._generate_examples(
+              pipeline=pipeline,
               filters={
                   **self._filters,
                   split_reference.reference_field.id: split['name'],
-              }
+              },
           )
           for split in split_reference.split_record_set.data
       }
     else:
-      return {'default': self._generate_examples(filters=self._filters)}
+      return {
+          'default': self._generate_examples(
+              pipeline=pipeline, filters=self._filters
+          )
+      }
 
   def _generate_examples(
       self,
+      pipeline: beam.Pipeline,
       filters: dict[str, Any],
-  ) -> split_builder_lib.SplitGenerator:
+  ) -> beam.PTransform:
     """Generates the examples for the given record set.
 
     Args:
+      pipeline: The Beam pipeline.
       filters: A dict of filters to apply to the records. The keys should be
         field names and the values should be the values to filter by. If a
         record matches all the filters, it will be included in the dataset.
 
-    Yields:
-      A tuple of (index, record) for each record in the dataset.
+    Returns:
+      A collection of (index, record) tuples, one per record in the dataset.
     """
     record_set = croissant_utils.get_record_set(
         self.builder_config.name, metadata=self.metadata
     )
     records = self.dataset.records(record_set.id, filters=filters)
-    for i, record in enumerate(records):
-      # Some samples might not be TFDS-compatible as-is, e.g. from croissant
-      # describing HuggingFace datasets, so we convert them here. This shouldn't
-      # impact datasets which are already TFDS-compatible.
-      record = conversion_utils.to_tfds_value(record, self.info.features)
-      yield i, record
+
+    def convert_to_tfds_format(
+        global_index: int,
+        record: Any,
+        features: feature_lib.FeatureConnector | None = None,
+    ) -> tuple[int, Any]:
+      if not features:
+        raise ValueError('features should not be None.')
+      return (
+          global_index,
+          conversion_utils.to_tfds_value(record, features),
+      )
+
+    return records.beam_reader(
+        pipeline=pipeline
+    ) | 'Convert to TFDS format' >> beam.MapTuple(
+        convert_to_tfds_format, features=self.info.features
+    )
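The returned transform pipes the records read by `beam_reader` through `beam.MapTuple`, which unpacks each `(index, record)` element into positional arguments and forwards extra keyword arguments (here `features`) to every call. A self-contained sketch of that Beam pattern, with illustrative names not taken from this PR:

```python
import apache_beam as beam


def tag_and_scale(index: int, value: int, scale: int = 1) -> tuple[int, int]:
  # MapTuple unpacks each (index, value) tuple into positional arguments;
  # keyword arguments such as `scale` are forwarded to every invocation.
  return index, value * scale


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([(0, 1), (1, 2), (2, 3)])
      | 'Scale' >> beam.MapTuple(tag_and_scale, scale=10)
      | beam.Map(print)  # emits (0, 10), (1, 20), (2, 30)
  )
```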