
Commit b31a385

Author: The TensorFlow Datasets Authors (committed)
Add support for splits in Croissant TFDS builder.
PiperOrigin-RevId: 652936333
1 parent a7347c0 commit b31a385

File tree: 3 files changed, +154 −17 lines changed


tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 49 additions & 16 deletions
@@ -149,6 +149,7 @@ def __init__(
       float_dtype: type_utils.TfdsDType | None = np.float32,
       mapping: Mapping[str, epath.PathLike] | None = None,
       overwrite_version: str | None = None,
+      filters: Mapping[str, Any] | None = None,
       **kwargs: Any,
   ):
     """Initializes a CroissantBuilder.
@@ -170,6 +171,10 @@ def __init__(
         it to `~/Downloads/document.csv`, you can specify
         `mapping={"document.csv": "~/Downloads/document.csv"}`.
       overwrite_version: Semantic version of the dataset to be set.
+      filters: A dict of filters to apply to the records at preparation time
+        (in the `_generate_examples` function). The keys should be field names
+        and the values should be the values to filter by. If a record matches
+        all the filters, it will be included in the dataset.
       **kwargs: kwargs to pass to GeneratorBasedBuilder directly.
     """
     if mapping is None:
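
For orientation, a minimal usage sketch of the new argument (not part of the commit). The `jsonld` and `record_set_ids` constructor arguments are assumed from the existing builder API, and the field name `language` is hypothetical; only `filters` is introduced here:

  from tensorflow_datasets.core.dataset_builders import croissant_builder

  # Prepare only the records whose (hypothetical) `language` field equals
  # 'en'; all other records are skipped in `_generate_examples`.
  builder = croissant_builder.CroissantBuilder(
      jsonld='/path/to/croissant.json',
      record_set_ids=['data'],
      filters={'language': 'en'},
  )
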
@@ -201,6 +206,7 @@ def __init__(

     self._int_dtype = int_dtype
     self._float_dtype = float_dtype
+    self._filters = filters or {}

     super().__init__(
         **kwargs,
@@ -222,19 +228,11 @@ def _info(self) -> dataset_info.DatasetInfo:
         disable_shuffling=self._disable_shuffling,
     )

-  def get_record_set(self, record_set_id: str):
-    """Returns the desired record set from self.metadata."""
-    for record_set in self.dataset.metadata.record_sets:
-      if huggingface_utils.convert_hf_name(record_set.id) == record_set_id:
-        return record_set
-    raise ValueError(
-        f'Did not find any record set with the name {record_set_id}.'
-    )
-
   def get_features(self) -> Optional[feature_lib.FeatureConnector]:
     """Infers the features dict for the required record set."""
-    record_set = self.get_record_set(self.builder_config.name)
+    record_set = croissant_utils.get_record_set(
+        self.builder_config.name, metadata=self.metadata
+    )
     fields = record_set.fields
     features = {}
     for field in fields:
@@ -249,18 +247,53 @@ def get_features(self) -> Optional[feature_lib.FeatureConnector]:
   def _split_generators(
       self, dl_manager: download.DownloadManager
   ) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
-    # This will be updated when partitions are implemented in Croissant, ref to:
-    # https://docs.google.com/document/d/1saz3usja6mk5ugJXNF64_uSXsOzIgbIV28_bu1QamVY
-    return {'default': self._generate_examples()}  # pylint: disable=unreachable
+    # If a split record set is joined to the required record set, we generate
+    # splits accordingly. Otherwise, we generate a single `default` split with
+    # all the records.
+    record_set = croissant_utils.get_record_set(
+        self.builder_config.name, metadata=self.metadata
+    )
+    if split_reference := croissant_utils.get_split_recordset(
+        record_set, metadata=self.metadata
+    ):
+      return {
+          split['name']: self._generate_examples(
+              filters={
+                  **self._filters,
+                  split_reference.reference_field.id: split['name'].encode(),
+              }
+          )
+          for split in split_reference.split_record_set.data
+      }
+    else:
+      return {'default': self._generate_examples(filters=self._filters)}

   def _generate_examples(
       self,
+      filters: dict[str, Any],
   ) -> split_builder_lib.SplitGenerator:
-    record_set = self.get_record_set(self.builder_config.name)
+    """Generates the examples for the given record set.
+
+    Args:
+      filters: A dict of filters to apply to the records. The keys should be
+        field names and the values should be the values to filter by. If a
+        record matches all the filters, it will be included in the dataset.
+
+    Yields:
+      A tuple of (index, record) for each record in the dataset.
+    """
+    record_set = croissant_utils.get_record_set(
+        self.builder_config.name, metadata=self.metadata
+    )
     records = self.dataset.records(record_set.id)
     for i, record in enumerate(records):
       # Some samples might not be TFDS-compatible as-is, e.g. from croissant
       # describing HuggingFace datasets, so we convert them here. This
       # shouldn't impact datasets which are already TFDS-compatible.
       record = huggingface_utils.convert_hf_value(record, self.info.features)
-      yield i, record
+      # Once partitions are implemented in Croissant, the filters will be
+      # applied by mlcroissant in `dataset.records` directly, e.g.:
+      # `records = records.filter(f == v for f, v in filters.items())`.
+      # For now, we apply them in TFDS.
+      if all(record[field] == value for field, value in filters.items()):
+        yield i, record
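
To make the split-to-filters mapping above concrete, here is a small self-contained sketch of the dict comprehension in `_split_generators` (the helper name `_filters_per_split`, the reference field id `data/split`, and the `language` filter are all hypothetical; split names are encoded to bytes because mlcroissant yields text values as bytes):

  def _filters_per_split(user_filters, split_record_set_data, reference_field_id):
    """Mirrors the per-split filter merging done in `_split_generators`."""
    return {
        split['name']: {
            **user_filters,
            reference_field_id: split['name'].encode(),
        }
        for split in split_record_set_data
    }

  splits = _filters_per_split(
      user_filters={'language': b'en'},
      split_record_set_data=[{'name': 'train'}, {'name': 'test'}],
      reference_field_id='data/split',
  )
  # -> {'train': {'language': b'en', 'data/split': b'train'},
  #     'test': {'language': b'en', 'data/split': b'test'}}
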

tensorflow_datasets/core/utils/croissant_utils.py

Lines changed: 73 additions & 1 deletion
@@ -17,15 +17,27 @@

 from __future__ import annotations

+import dataclasses
 import typing

 from tensorflow_datasets.core.utils import huggingface_utils
+from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc

 if typing.TYPE_CHECKING:
   # pylint: disable=g-bad-import-order
   import mlcroissant as mlc

-_HUGGINGFACE_URL_PREFIX = 'https://huggingface.co/datasets/'
+_HUGGINGFACE_URL_PREFIX = "https://huggingface.co/datasets/"
+
+
+@dataclasses.dataclass(frozen=True)
+class SplitReference:
+  """Information about a split reference in a Croissant dataset."""
+
+  # A split record set in a Croissant dataset.
+  split_record_set: mlc.RecordSet
+  # A field from another record set that references split_record_set.
+  reference_field: mlc.Field


 def get_dataset_name(dataset: mlc.Dataset) -> str:
@@ -41,6 +53,66 @@ def get_tfds_dataset_name(dataset: mlc.Dataset) -> str:
   return huggingface_utils.convert_hf_name(dataset_name)


+def get_record_set(record_set_id: str, metadata: mlc.Metadata) -> mlc.RecordSet:
+  """Returns the desired record set from a dataset's metadata."""
+  for record_set in metadata.record_sets:
+    if huggingface_utils.convert_hf_name(record_set.id) == record_set_id:
+      return record_set
+  raise ValueError(
+      f"Did not find any record set with the name {record_set_id}."
+  )
+
+
+def get_field(field_id: str, metadata: mlc.Metadata) -> mlc.Field:
+  """Returns the desired field from a dataset's metadata."""
+  for record_set in metadata.record_sets:
+    for field in record_set.fields:
+      if field.id == field_id:
+        return field
+  raise ValueError(f"Did not find any field with the name {field_id}.")
+
+
+def get_record_set_for_field(
+    field_id: str, metadata: mlc.Metadata
+) -> mlc.RecordSet:
+  """Given a field id, returns the record set it belongs to, if any."""
+  for record_set in metadata.record_sets:
+    for field in record_set.fields:
+      if field.id == field_id:
+        return record_set
+  raise ValueError(f"Did not find any record set with field {field_id}.")
+
+
+def get_split_recordset(
+    record_set: mlc.RecordSet, metadata: mlc.Metadata
+) -> SplitReference | None:
+  """If the given record set references a split record set, returns it.
+
+  Args:
+    record_set: The record set to check.
+    metadata: The metadata of the dataset.
+
+  Returns:
+    If found, a SplitReference holding the split record set and the field
+    that references it; None otherwise.
+  """
+  for field in record_set.fields:
+    if field.references and field.references.field:
+      # Check that the referenced record set is of type `cr:Split`.
+      referenced_field = get_field(field.references.field, metadata)
+      record_sets = [
+          node
+          for node in referenced_field.predecessors
+          if isinstance(node, mlc.RecordSet)
+      ]
+      if not record_sets:
+        raise ValueError(f"Field {field.id} has no parent record set.")
+      referenced_record_set = record_sets[0]
+      if str(mlc.DataType.SPLIT) in referenced_record_set.data_types:
+        return SplitReference(referenced_record_set, field)
+  return None
+
+
 def get_record_set_ids(metadata: mlc.Metadata) -> list[str]:
   """Returns record set ids of the given MLcroissant metadata.

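A short usage sketch of the new lookup helpers, mirroring the construction pattern used in the test file below (the `labels` record set and its field id are hypothetical; both helpers raise ValueError when nothing matches):

  import mlcroissant as mlc
  from tensorflow_datasets.core.utils import croissant_utils

  metadata = mlc.Metadata(
      name='dummy',
      url='dum.my',
      record_sets=[
          mlc.RecordSet(
              id='labels',
              fields=[
                  mlc.Field(id='labels/label', data_types=mlc.DataType.TEXT)
              ],
          )
      ],
  )
  record_set = croissant_utils.get_record_set('labels', metadata=metadata)
  field = croissant_utils.get_field('labels/label', metadata=metadata)
  assert croissant_utils.get_record_set_for_field(
      'labels/label', metadata
  ).id == 'labels'
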
tensorflow_datasets/core/utils/croissant_utils_test.py

Lines changed: 32 additions & 0 deletions
@@ -55,3 +55,35 @@ def test_get_record_set_ids():
   )
   record_set_ids = croissant_utils.get_record_set_ids(metadata=metadata)
   assert record_set_ids == ['record_set_1']
+
+
+def test_get_split_recordset_with_no_split_recordset():
+  record_sets = [
+      mlc.RecordSet(
+          id='labels',
+          key='name',
+          fields=[
+              mlc.Field(
+                  id='labels/label',
+                  name='label',
+                  data_types=mlc.DataType.TEXT,
+              )
+          ],
+          data=[{'label': 'bird'}, {'label': 'bike'}],
+      ),
+      mlc.RecordSet(
+          id='samples',
+          fields=[
+              mlc.Field(
+                  id='samples/label',
+                  data_types=mlc.DataType.TEXT,
+                  references=mlc.Source(field='labels/label'),
+              )
+          ],
+      ),
+  ]
+  metadata = mlc.Metadata(name='dummy', url='dum.my', record_sets=record_sets)
+  split_recordset = croissant_utils.get_split_recordset(
+      record_set=metadata.record_sets[0], metadata=metadata
+  )
+  assert split_recordset is None

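The commit only tests the negative case; a complementary positive-case test might look like the following sketch (not part of this commit; it assumes `mlc.RecordSet` accepts `data_types` at construction, which the `data_types` read in `get_split_recordset` suggests but the commit itself does not demonstrate):

  def test_get_split_recordset():
    record_sets = [
        mlc.RecordSet(
            id='splits',
            key='name',
            # Marking the record set as `cr:Split` is exactly what
            # `get_split_recordset` checks for via `mlc.DataType.SPLIT`.
            data_types=[mlc.DataType.SPLIT],
            fields=[
                mlc.Field(
                    id='splits/name',
                    name='name',
                    data_types=mlc.DataType.TEXT,
                )
            ],
            data=[{'name': 'train'}, {'name': 'test'}],
        ),
        mlc.RecordSet(
            id='samples',
            fields=[
                mlc.Field(
                    id='samples/split',
                    data_types=mlc.DataType.TEXT,
                    references=mlc.Source(field='splits/name'),
                )
            ],
        ),
    ]
    metadata = mlc.Metadata(name='dummy', url='dum.my', record_sets=record_sets)
    split_reference = croissant_utils.get_split_recordset(
        record_set=metadata.record_sets[1], metadata=metadata
    )
    assert split_reference is not None
    assert split_reference.split_record_set.id == 'splits'
    assert split_reference.reference_field.id == 'samples/split'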