Skip to content

Commit 1fd44fc

Browse files
author
The TensorFlow Datasets Authors
committed
Add an option to skip cr:Split record sets when getting record_sets_ids to a CroissantBuilder.
PiperOrigin-RevId: 647237046
1 parent b9f906d commit 1fd44fc

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ def __init__(
158158
a URL.
159159
record_set_ids: The @ids of the record sets for the dataset. Each record
160160
set will correspond to a separate config. If not specified, a config
161-
will be generated for each record set defined in the Croissant JSON-LD.
161+
will be generated for each record set defined in the Croissant JSON-LD,
162+
except for the record sets which specify `cr:data`.
162163
disable_shuffling: Specify whether to shuffle the examples.
163164
int_dtype: The dtype to use for TFDS integer features. Defaults to
164165
np.int64.
@@ -186,9 +187,7 @@ def __init__(
186187
self.RELEASE_NOTES = {} # pylint: disable=invalid-name
187188

188189
if not record_set_ids:
189-
record_set_ids = [
190-
record_set.id for record_set in self.metadata.record_sets
191-
]
190+
record_set_ids = croissant_utils.get_record_set_ids(self.metadata)
192191
config_names = [
193192
huggingface_utils.convert_hf_name(record_set)
194193
for record_set in record_set_ids

tensorflow_datasets/core/utils/croissant_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,20 @@ def get_tfds_dataset_name(dataset: mlc.Dataset) -> str:
3939
"""Returns TFDS compatible dataset name of the given MLcroissant dataset."""
4040
dataset_name = get_dataset_name(dataset)
4141
return huggingface_utils.convert_hf_name(dataset_name)
42+
43+
44+
def get_record_set_ids(metadata: mlc.Metadata) -> typing.Sequence[str]:
45+
"""Returns record set ids of the given MLcroissant metadata.
46+
47+
Record sets which have the attribute `cr:Data` are excluded (e.g. splits that
48+
specify split or labels mappings).
49+
50+
Args:
51+
metadata: The metadata of the dataset.
52+
"""
53+
record_set_ids = []
54+
for record_set in metadata.record_sets:
55+
if record_set.data is not None:
56+
continue
57+
record_set_ids.append(record_set.id)
58+
return record_set_ids

tensorflow_datasets/core/utils/croissant_utils_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,24 @@ def test_get_tfds_dataset_name(croissant_name, croissant_url, tfds_name):
3434
metadata = mlc.Metadata(name=croissant_name, url=croissant_url)
3535
dataset = mlc.Dataset.from_metadata(metadata)
3636
assert croissant_utils.get_tfds_dataset_name(dataset) == tfds_name
37+
38+
39+
def test_get_record_set_ids():
40+
metadata = mlc.Metadata(
41+
name='dummy_dataset',
42+
url='https://dummy_url',
43+
record_sets=[
44+
mlc.RecordSet(
45+
id='record_set_1',
46+
fields=[],
47+
),
48+
mlc.RecordSet(
49+
id='record_set_2',
50+
data_types=['http://mlcommons.org/croissant/Split'],
51+
fields=[mlc.Field(name='name', data_types=mlc.DataType.TEXT)],
52+
data=[{'name': 'train'}, {'name': 'test'}],
53+
),
54+
],
55+
)
56+
record_set_ids = croissant_utils.get_record_set_ids(metadata=metadata)
57+
assert record_set_ids == ['record_set_1']

0 commit comments

Comments
 (0)