Skip to content

Commit 697d2b3

Browse files
marcenacp and The TensorFlow Datasets Authors
authored and committed
Use mlcroissant's Beam Reader in TFDS in CroissantBuilder.
PiperOrigin-RevId: 671696325
1 parent 1a2a924 commit 697d2b3

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
```
3535
"""
3636

37+
from __future__ import annotations
38+
3739
from collections.abc import Mapping
3840
from typing import Any, Dict, Optional, Sequence
3941

@@ -53,6 +55,7 @@
5355
from tensorflow_datasets.core.utils import croissant_utils
5456
from tensorflow_datasets.core.utils import type_utils
5557
from tensorflow_datasets.core.utils import version as version_utils
58+
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
5659
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
5760
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
5861

@@ -142,7 +145,7 @@ class CroissantBuilder(
142145
def __init__(
143146
self,
144147
*,
145-
jsonld: epath.PathLike,
148+
jsonld: epath.PathLike | Mapping[str, Any],
146149
record_set_ids: Sequence[str] | None = None,
147150
disable_shuffling: bool | None = False,
148151
int_dtype: type_utils.TfdsDType | None = np.int64,
@@ -245,7 +248,9 @@ def get_features(self) -> Optional[feature_lib.FeatureConnector]:
245248
return features_dict.FeaturesDict(features)
246249

247250
def _split_generators(
248-
self, dl_manager: download.DownloadManager
251+
self,
252+
dl_manager: download.DownloadManager,
253+
pipeline: beam.Pipeline,
249254
) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
250255
# If a split recordset is joined for the required record set, we generate
251256
# splits accordingly. Otherwise, it generates a single `default` split with
@@ -258,37 +263,56 @@ def _split_generators(
258263
):
259264
return {
260265
split['name']: self._generate_examples(
266+
pipeline=pipeline,
261267
filters={
262268
**self._filters,
263269
split_reference.reference_field.id: split['name'],
264-
}
270+
},
265271
)
266272
for split in split_reference.split_record_set.data
267273
}
268274
else:
269-
return {'default': self._generate_examples(filters=self._filters)}
275+
return {
276+
'default': self._generate_examples(
277+
pipeline=pipeline, filters=self._filters
278+
)
279+
}
270280

271281
def _generate_examples(
272282
self,
283+
pipeline: beam.Pipeline,
273284
filters: dict[str, Any],
274-
) -> split_builder_lib.SplitGenerator:
285+
) -> beam.PTransform:
275286
"""Generates the examples for the given record set.
276287
277288
Args:
289+
pipeline: The Beam pipeline.
278290
filters: A dict of filters to apply to the records. The keys should be
279291
field names and the values should be the values to filter by. If a
280292
record matches all the filters, it will be included in the dataset.
281293
282-
Yields:
283-
A tuple of (index, record) for each record in the dataset.
294+
Returns:
295+
A collection with tuple of (index, record) for each record in the dataset.
284296
"""
285297
record_set = croissant_utils.get_record_set(
286298
self.builder_config.name, metadata=self.metadata
287299
)
288300
records = self.dataset.records(record_set.id, filters=filters)
289-
for i, record in enumerate(records):
290-
# Some samples might not be TFDS-compatible as-is, e.g. from croissant
291-
# describing HuggingFace datasets, so we convert them here. This shouldn't
292-
# impact datasets which are already TFDS-compatible.
293-
record = conversion_utils.to_tfds_value(record, self.info.features)
294-
yield i, record
301+
302+
def convert_to_tfds_format(
303+
global_index: int,
304+
record: Any,
305+
features: feature_lib.FeatureConnector | None = None,
306+
) -> tuple[int, Any]:
307+
if not features:
308+
raise ValueError('features should not be None.')
309+
return (
310+
global_index,
311+
conversion_utils.to_tfds_value(record, features),
312+
)
313+
314+
return records.beam_reader(
315+
pipeline=pipeline
316+
) | 'Convert to TFDS format' >> beam.MapTuple(
317+
convert_to_tfds_format, features=self.info.features
318+
)

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
import numpy as np
1919
import pytest
2020
from tensorflow_datasets import testing
21-
from tensorflow_datasets.core import FileFormat
21+
from tensorflow_datasets.core import file_adapters
2222
from tensorflow_datasets.core.dataset_builders import croissant_builder
2323
from tensorflow_datasets.core.features import image_feature
2424
from tensorflow_datasets.core.features import text_feature
2525
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
2626

27+
FileFormat = file_adapters.FileFormat
28+
2729

2830
DUMMY_ENTRIES = entries = [
2931
{"index": i, "text": f"Dummy example {i}"} for i in range(2)

0 commit comments

Comments (0)