Skip to content

Commit e9d777a

Browse files
author
The TensorFlow Datasets Authors
committed
CroissantBuilder: Convert mlcroissant examples into TFDS examples before yielding.
PiperOrigin-RevId: 649876021
1 parent 9b888a4 commit e9d777a

File tree

3 files changed

+86
-37
lines changed

3 files changed

+86
-37
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,4 +259,8 @@ def _generate_examples(
259259
record_set = self.get_record_set(self.builder_config.name)
260260
records = self.dataset.records(record_set.id)
261261
for i, record in enumerate(records):
262+
# Some samples might not be TFDS-compatible as-is, e.g. from croissant
263+
# describing HuggingFace datasets, so we convert them here. This shouldn't
264+
# impact datasets which are already TFDS-compatible.
265+
record = huggingface_utils.convert_hf_value(record, self.info.features)
262266
yield i, record

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 78 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,19 @@
2525
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
2626

2727

28+
DUMMY_ENTRIES = entries = [
29+
{"index": i, "text": f"Dummy example {i}"} for i in range(2)
30+
]
31+
DUMMY_ENTRIES_WITH_NONE_VALUES = [
32+
{"index": 0, "text": "Dummy example 0"},
33+
{"index": 1, "text": None},
34+
]
35+
DUMMY_ENTRIES_WITH_CONVERTED_NONE_VALUES = [
36+
{"index": 0, "text": "Dummy example 0"},
37+
{"index": 1, "text": ""},
38+
]
39+
40+
2841
@pytest.mark.parametrize(
2942
["field", "feature_type", "int_dtype", "float_dtype"],
3043
[
@@ -104,43 +117,72 @@ def test_complex_datatype_converter(field, feature_type):
104117
assert isinstance(actual_feature, feature_type)
105118

106119

107-
class CroissantBuilderTest(testing.TestCase):
120+
@pytest.fixture(name="crs_builder")
121+
def mock_croissant_dataset_builder(tmp_path, request):
122+
dataset_name = request.param["dataset_name"]
123+
with testing.dummy_croissant_file(
124+
dataset_name=dataset_name,
125+
entries=request.param["entries"],
126+
) as croissant_file:
127+
builder = croissant_builder.CroissantBuilder(
128+
jsonld=croissant_file,
129+
file_format=FileFormat.ARRAY_RECORD,
130+
disable_shuffling=True,
131+
data_dir=tmp_path,
132+
)
133+
yield builder
108134

109-
@classmethod
110-
def setUpClass(cls):
111-
super(CroissantBuilderTest, cls).setUpClass()
112135

113-
with testing.dummy_croissant_file() as croissant_file:
114-
cls._tfds_tmp_dir = testing.make_tmp_dir()
115-
cls.builder = croissant_builder.CroissantBuilder(
116-
jsonld=croissant_file,
117-
file_format=FileFormat.ARRAY_RECORD,
118-
disable_shuffling=True,
119-
data_dir=cls._tfds_tmp_dir,
120-
)
121-
cls.builder.download_and_prepare()
136+
@pytest.mark.parametrize(
137+
"crs_builder",
138+
[
139+
{"dataset_name": "DummyDataset", "entries": DUMMY_ENTRIES},
140+
{
141+
"dataset_name": "DummyDatasetWithNoneValues",
142+
"entries": DUMMY_ENTRIES_WITH_NONE_VALUES,
143+
},
144+
],
145+
indirect=True,
146+
)
147+
def test_croissant_builder(crs_builder):
148+
assert crs_builder.version == "1.2.0"
149+
assert (
150+
crs_builder._info().citation
151+
== "@article{dummyarticle, title={title}, author={author}, year={2020}}"
152+
)
153+
assert crs_builder._info().description == "Dummy description."
154+
assert crs_builder._info().homepage == "https://dummy_url"
155+
assert crs_builder._info().redistribution_info.license == "Public"
156+
assert len(crs_builder.metadata.record_sets) == 1
157+
assert crs_builder.metadata.record_sets[0].id == "jsonl"
158+
assert (
159+
crs_builder.metadata.ctx.conforms_to.value
160+
== "http://mlcommons.org/croissant/1.0"
161+
)
122162

123-
def test_dataset_info(self):
124-
assert self.builder.name == "dummydataset"
125-
assert self.builder.version == "1.2.0"
126-
assert (
127-
self.builder._info().citation
128-
== "@article{dummyarticle, title={title}, author={author}, year={2020}}"
129-
)
130-
assert self.builder._info().description == "Dummy description."
131-
assert self.builder._info().homepage == "https://dummy_url"
132-
assert self.builder._info().redistribution_info.license == "Public"
133-
assert len(self.builder.metadata.record_sets) == 1
134-
assert self.builder.metadata.record_sets[0].id == "jsonl"
135-
assert (
136-
self.builder.metadata.ctx.conforms_to.value
137-
== "http://mlcommons.org/croissant/1.0"
138-
)
139163

140-
def test_generated_samples(self):
141-
for split_name in ["all", "default"]:
142-
data_source = self.builder.as_data_source(split=split_name)
143-
assert len(data_source) == 2
144-
for i in range(2):
145-
assert data_source[i]["index"] == i
146-
assert data_source[i]["text"].decode() == f"Dummy example {i}"
164+
@pytest.mark.parametrize(
165+
"crs_builder,expected_entries",
166+
[
167+
(
168+
{"dataset_name": "DummyDataset", "entries": DUMMY_ENTRIES},
169+
DUMMY_ENTRIES,
170+
),
171+
(
172+
{
173+
"dataset_name": "DummyDatasetWithNoneValues",
174+
"entries": DUMMY_ENTRIES_WITH_NONE_VALUES,
175+
},
176+
DUMMY_ENTRIES_WITH_CONVERTED_NONE_VALUES,
177+
),
178+
],
179+
indirect=["crs_builder"],
180+
)
181+
@pytest.mark.parametrize("split_name", ["all", "default"])
182+
def test_download_and_prepare(crs_builder, expected_entries, split_name):
183+
crs_builder.download_and_prepare()
184+
data_source = crs_builder.as_data_source(split=split_name)
185+
assert len(data_source) == 2
186+
for i in range(2):
187+
assert data_source[i]["index"] == expected_entries[i]["index"]
188+
assert data_source[i]["text"].decode() == expected_entries[i]["text"]

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,13 @@ def mock_huggingface_dataset_builder(
107107
'foo/bar', 'config', other_arg='this is another arg'
108108
)
109109
login_to_hf.assert_called_once_with('SECRET_TOKEN')
110+
yield builder
111+
112+
113+
def test_dataset_info(builder):
110114
assert builder.info.description == 'description'
111115
assert builder.info.citation == 'citation from the hub'
112116
assert builder.info.redistribution_info.license == 'test-license'
113-
yield builder
114117

115118

116119
def test_download_and_prepare(builder):

0 commit comments

Comments
 (0)