Skip to content

Commit 7b7b708

Browse files
author
The TensorFlow Datasets Authors
committed
Add parameters to the dummy_croissant_file() function used in testing.
PiperOrigin-RevId: 649587388
1 parent cd25e16 commit 7b7b708

File tree

2 files changed

+90
-53
lines changed

2 files changed

+90
-53
lines changed

tensorflow_datasets/testing/test_utils.py

Lines changed: 66 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
import dataclasses
2222
import datetime
2323
import functools
24+
import hashlib
2425
import json
2526
import os
2627
import pathlib
2728
import subprocess
2829
import tempfile
29-
from typing import Any, Iterator, Mapping
30+
from typing import Any, Iterator, Mapping, Sequence
3031
from unittest import mock
3132

3233
from etils import epath
@@ -711,74 +712,94 @@ def now(cls, tz=None) -> datetime.datetime:
711712

712713

713714
@contextlib.contextmanager
def dummy_croissant_file(
    dataset_name: str = 'DummyDataset',
    entries: Sequence[dict[str, Any]] | None = None,
    raw_data_filename: epath.PathLike = 'raw_data.jsonl',
    croissant_filename: epath.PathLike = 'croissant.json',
) -> Iterator[epath.Path]:
  """Yields temporary path to a dummy Croissant file.

  The function creates a temporary directory that stores raw data files and the
  Croissant JSON-LD.

  Args:
    dataset_name: The name of the dataset.
    entries: A list of dictionaries representing the dataset's entries. Each
      dictionary should contain an 'index' and a 'text' key. If None (or
      empty), the function will create two entries with indices 0 and 1 and
      dummy text.
    raw_data_filename: Filename of the raw data file.
    croissant_filename: Filename of the Croissant JSON-LD file.
  """
  # NOTE: a falsy `entries` (None or []) falls back to the two default rows.
  if not entries:
    entries = [{'index': i, 'text': f'Dummy example {i}'} for i in range(2)]

  fields = [
      mlc.Field(
          name='index',
          description='The sample index.',
          data_types=mlc.DataType.INTEGER,
          source=mlc.Source(
              file_object='raw_data',
              extract=mlc.Extract(column='index'),
          ),
      ),
      mlc.Field(
          name='text',
          description='The dummy sample text.',
          data_types=mlc.DataType.TEXT,
          source=mlc.Source(
              file_object='raw_data',
              extract=mlc.Extract(column='text'),
          ),
      ),
  ]

  record_sets = [
      mlc.RecordSet(
          id='jsonl',
          description='Dummy record set.',
          fields=fields,
      )
  ]

  with tempfile.TemporaryDirectory() as tempdir:
    tempdir = epath.Path(tempdir)

    # Write raw examples to tempdir/data.
    raw_data_dir = tempdir / 'data'
    raw_data_dir.mkdir()
    raw_data_file = raw_data_dir / raw_data_filename
    raw_data_file.write_text('\n'.join(map(json.dumps, entries)))

    # Compute the sha256 of the *actual* written file, so the checksum in the
    # metadata always matches, whatever `entries` the caller passed.
    raw_data_file_content = raw_data_file.read_text()
    sha256 = hashlib.sha256(raw_data_file_content.encode()).hexdigest()
    distribution = [
        mlc.FileObject(
            id='raw_data',
            description='File with the data.',
            encoding_format='application/jsonlines',
            content_url=f'data/{raw_data_filename}',
            sha256=sha256,
        ),
    ]
    dummy_metadata = mlc.Metadata(
        name=dataset_name,
        description='Dummy description.',
        cite_as=(
            '@article{dummyarticle, title={title}, author={author},'
            ' year={2020}}'
        ),
        url='https://dummy_url',
        distribution=distribution,
        record_sets=record_sets,
        version='1.2.0',
        license='Public',
    )

    # Write Croissant JSON-LD to tempdir.
    croissant_file = tempdir / croissant_filename
    croissant_file.write_text(json.dumps(dummy_metadata.to_json(), indent=2))

    yield croissant_file

tensorflow_datasets/testing/test_utils_test.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,19 +229,35 @@ def is_lambda(fn):
229229
assert not is_lambda(gcs_utils.gcs_dataset_info_files)
230230

231231

@pytest.mark.parametrize(
    'entries',
    [
        # All-present values: both rows carry text.
        [
            {'text': 'Dummy example 0', 'index': 0},
            {'text': 'Dummy example 1', 'index': 1},
        ],
        # A None text value: exercises the null-handling path in records().
        [
            {'text': 'Dummy example 0', 'index': 0},
            {'text': None, 'index': 1},
        ],
    ],
)
def test_dummy_croissant_file(entries):
  """Checks that dummy_croissant_file produces a loadable Croissant dataset."""
  with test_utils.dummy_croissant_file(entries=entries) as croissant_file:
    dataset = mlc.Dataset(jsonld=croissant_file)

    assert dataset.jsonld == croissant_file
    assert dataset.mapping is None
    assert dataset.metadata.description == 'Dummy description.'
    assert dataset.metadata.url == 'https://dummy_url'
    assert dataset.metadata.version == '1.2.0'

    assert [record_set.id for record_set in dataset.metadata.record_sets] == [
        'jsonl'
    ]
    # Records must round-trip the entries we wrote; text values come back as
    # bytes when present, and stay None when the entry's text was None.
    for i, record in enumerate(dataset.records('jsonl')):
      assert record['index'] == entries[i]['index']
      if record['text'] is not None:
        assert record['text'].decode() == entries[i]['text']
      else:
        assert record['text'] == entries[i]['text']

0 commit comments

Comments
 (0)