Skip to content

Commit 7b91b21

Browse files
author
The TensorFlow Datasets Authors
committed
Add splits to the test croissant.
PiperOrigin-RevId: 700635612
1 parent a8c87b5 commit 7b91b21

File tree

2 files changed

+110
-31
lines changed

2 files changed

+110
-31
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,21 @@
3131
FileFormat = file_adapters.FileFormat
3232

3333

34-
DUMMY_ENTRIES = entries = [
35-
{"index": i, "text": f"Dummy example {i}"} for i in range(2)
34+
DUMMY_ENTRIES = [
35+
{
36+
"index": i,
37+
"text": f"Dummy example {i}",
38+
"split": "train" if i == 0 else "test",
39+
}
40+
for i in range(2)
3641
]
3742
DUMMY_ENTRIES_WITH_NONE_VALUES = [
38-
{"index": 0, "text": "Dummy example 0"},
39-
{"index": 1, "text": None},
43+
{"split": "train", "index": 0, "text": "Dummy example 0"},
44+
{"split": "test", "index": 1, "text": None},
4045
]
4146
DUMMY_ENTRIES_WITH_CONVERTED_NONE_VALUES = [
42-
{"index": 0, "text": "Dummy example 0"},
43-
{"index": 1, "text": ""},
47+
{"split": "train", "index": 0, "text": "Dummy example 0"},
48+
{"split": "test", "index": 1, "text": ""},
4449
]
4550

4651

@@ -173,6 +178,7 @@ def mock_croissant_dataset_builder(tmp_path, request):
173178
with testing.dummy_croissant_file(
174179
dataset_name=dataset_name,
175180
entries=request.param["entries"],
181+
split_names=["train", "test"],
176182
) as croissant_file:
177183
builder = croissant_builder.CroissantBuilder(
178184
jsonld=croissant_file,
@@ -203,8 +209,12 @@ def test_croissant_builder(crs_builder):
203209
assert crs_builder._info().description == "Dummy description."
204210
assert crs_builder._info().homepage == "https://dummy_url"
205211
assert crs_builder._info().redistribution_info.license == "Public"
206-
assert len(crs_builder.metadata.record_sets) == 1
207-
assert crs_builder.metadata.record_sets[0].id == "jsonl"
212+
# One `split` and one `jsonl` recordset.
213+
assert len(crs_builder.metadata.record_sets) == 2
214+
assert set([rs.id for rs in crs_builder.metadata.record_sets]) == {
215+
"jsonl",
216+
"split",
217+
}
208218
assert (
209219
crs_builder.metadata.ctx.conforms_to.value
210220
== "http://mlcommons.org/croissant/1.0"
@@ -228,11 +238,11 @@ def test_croissant_builder(crs_builder):
228238
],
229239
indirect=["crs_builder"],
230240
)
231-
@pytest.mark.parametrize("split_name", ["all", "default"])
241+
@pytest.mark.parametrize("split_name", ["train", "test"])
232242
def test_download_and_prepare(crs_builder, expected_entries, split_name):
233243
crs_builder.download_and_prepare()
234244
data_source = crs_builder.as_data_source(split=split_name)
235-
assert len(data_source) == 2
245+
assert len(data_source) == 1
236246
for entry, expected_entry in zip(data_source, expected_entries):
237247
assert entry["index"] == expected_entry["index"]
238248
assert entry["text"].decode() == expected_entry["text"]

tensorflow_datasets/testing/test_utils.py

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,7 @@ def dummy_croissant_file(
723723
entries: Sequence[dict[str, Any]] | None = None,
724724
raw_data_filename: epath.PathLike = 'raw_data.jsonl',
725725
croissant_filename: epath.PathLike = 'croissant.json',
726+
split_names: Sequence[str] | None = None,
726727
) -> Iterator[epath.Path]:
727728
"""Yields temporary path to a dummy Croissant file.
728729
@@ -732,13 +733,29 @@ def dummy_croissant_file(
732733
Args:
733734
dataset_name: The name of the dataset.
734735
entries: A list of dictionaries representing the dataset's entries. Each
735-
dictionary should contain an 'index' and a 'text' key. If None, the
736-
function will create two entries with indices 0 and 1 and dummy text.
737-
raw_data_filename: Filename of the raw data file.
736+
dictionary should contain an 'index', a 'text', and a `split` key. If
737+
None, the function will create two entries with indices 0 and 1 and dummy
738+
text, and with the first entry belonging to the split `train` and the
739+
second to `test`.
740+
raw_data_filename: Filename of the raw data file. If `split_names` is provided,
741+
the function will create a raw data file for each split, including the
742+
split name before the file extension.
738743
croissant_filename: Filename of the Croissant JSON-LD file.
744+
split_names: A list of split names to populate the split record set with. If
745+
`split_names` is defined, each entry's `split` key must match one of the
746+
given split names, and a split record set is created from them. If None,
747+
no split record set is created (the default entries still carry `train`
748+
and `test` split keys).
739749
"""
740750
if entries is None:
741-
entries = [{'index': i, 'text': f'Dummy example {i}'} for i in range(2)]
751+
entries = [
752+
{
753+
'index': i,
754+
'text': f'Dummy example {i}',
755+
'split': 'train' if i % 2 == 0 else 'test',
756+
}
757+
for i in range(2)
758+
]
742759

743760
fields = [
744761
mlc.Field(
@@ -771,29 +788,82 @@ def dummy_croissant_file(
771788
fields=fields,
772789
)
773790
]
791+
if split_names:
792+
record_sets[0].fields.append(
793+
mlc.Field(
794+
id='jsonl/split',
795+
name='jsonl/split',
796+
description='The dummy split.',
797+
data_types=mlc.DataType.TEXT,
798+
source=mlc.Source(
799+
file_object='raw_data',
800+
extract=mlc.Extract(file_property='fullpath'),
801+
transforms=[mlc.Transform(regex='.*(.+).+jsonl$')],
802+
),
803+
references=mlc.Source(field='split/name'),
804+
),
805+
)
806+
record_sets.append(
807+
mlc.RecordSet(
808+
id='split',
809+
name='split',
810+
key='split/name',
811+
data_types=[mlc.DataType.SPLIT],
812+
description='Dummy split.',
813+
fields=[
814+
mlc.Field(
815+
id='split/name',
816+
name='split/name',
817+
description='The dummy split name.',
818+
data_types=mlc.DataType.TEXT,
819+
)
820+
],
821+
data=[{'split/name': split_name} for split_name in split_names],
822+
)
823+
)
774824

775825
with tempfile.TemporaryDirectory() as tempdir:
776826
tempdir = epath.Path(tempdir)
777827

778828
# Write raw examples to tempdir/data.
779829
raw_data_dir = tempdir / 'data'
780830
raw_data_dir.mkdir()
781-
raw_data_file = raw_data_dir / raw_data_filename
782-
raw_data_file.write_text('\n'.join(map(json.dumps, entries)))
783-
784-
# Get the actual raw file's hash, set distribution and metadata.
785-
raw_data_file_content = raw_data_file.read_text()
786-
sha256 = hashlib.sha256(raw_data_file_content.encode()).hexdigest()
787-
distribution = [
788-
mlc.FileObject(
789-
id='raw_data',
790-
name='raw_data',
791-
description='File with the data.',
792-
encoding_format='application/jsonlines',
793-
content_url=f'data/{raw_data_filename}',
794-
sha256=sha256,
795-
),
796-
]
831+
if split_names:
832+
parts = str(raw_data_filename).split('.')
833+
file_name, extension = '.'.join(parts[:-1]), parts[-1]
834+
for split_name in split_names:
835+
raw_data_file = raw_data_dir / (
836+
file_name + '_' + split_name + '.' + extension
837+
)
838+
split_entries = [
839+
entry for entry in entries if entry['split'] == split_name
840+
]
841+
raw_data_file.write_text('\n'.join(map(json.dumps, split_entries)))
842+
distribution = [
843+
mlc.FileSet(
844+
id='raw_data',
845+
name='raw_data',
846+
description='Files with the data.',
847+
encoding_format='application/jsonlines',
848+
includes=f'data/{file_name}*.{extension}',
849+
),
850+
]
851+
else:
852+
raw_data_file = raw_data_dir / raw_data_filename
853+
raw_data_file.write_text('\n'.join(map(json.dumps, entries)))
854+
# Get the actual raw file's hash, set distribution and metadata.
855+
raw_data_file_content = raw_data_file.read_text()
856+
sha256 = hashlib.sha256(raw_data_file_content.encode()).hexdigest()
857+
distribution = [
858+
mlc.FileObject(
859+
id='raw_data',
860+
name='raw_data',
861+
description='File with the data.',
862+
encoding_format='application/jsonlines',
863+
content_url=f'data/{raw_data_filename}',
864+
sha256=sha256,
865+
),
866+
]
797867
dummy_metadata = mlc.Metadata(
798868
name=dataset_name,
799869
description='Dummy description.',
@@ -807,7 +877,6 @@ def dummy_croissant_file(
807877
version='1.2.0',
808878
license='Public',
809879
)
810-
811880
# Write Croissant JSON-LD to tempdir.
812881
croissant_file = tempdir / croissant_filename
813882
croissant_file.write_text(json.dumps(dummy_metadata.to_json(), indent=2))

0 commit comments

Comments
 (0)