Skip to content

Commit 7689579

Browse files
fineguy and The TensorFlow Datasets Authors
authored and committed
Add a context manager to create a dummy Croissant file.
PiperOrigin-RevId: 640891862
1 parent 7eeb200 commit 7689579

File tree

4 files changed

+110
-87
lines changed

4 files changed

+110
-87
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 11 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515

1616
"""Tests for croissant_builder."""
1717

18-
import json
19-
import tempfile
20-
21-
from etils import epath
2218
import numpy as np
2319
import pytest
2420
from tensorflow_datasets import testing
@@ -28,61 +24,6 @@
2824
from tensorflow_datasets.core.features import text_feature
2925
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
3026

31-
DUMMY_ENTRIES = [{"index": i, "text": f"Dummy example {i}"} for i in range(2)]
32-
33-
34-
def get_dummy_metadata():
35-
distribution = [
36-
mlc.FileObject(
37-
id="raw_data",
38-
description="File with the data.",
39-
encoding_format="application/jsonlines",
40-
content_url="data/raw_data.jsonl",
41-
sha256=(
42-
"ec6a2e5865be2c3ea2bf41817bf9ca78cbfcdd60bce0282721da8625a28fd10d"
43-
),
44-
),
45-
]
46-
record_sets = [
47-
mlc.RecordSet(
48-
id="jsonl",
49-
description="Dummy record set.",
50-
fields=[
51-
mlc.Field(
52-
name="index",
53-
description="The sample index.",
54-
data_types=mlc.DataType.INTEGER,
55-
source=mlc.Source(
56-
file_object="raw_data",
57-
extract=mlc.Extract(column="index"),
58-
),
59-
),
60-
mlc.Field(
61-
name="text",
62-
description="The dummy sample text.",
63-
data_types=mlc.DataType.TEXT,
64-
source=mlc.Source(
65-
file_object="raw_data",
66-
extract=mlc.Extract(column="text"),
67-
),
68-
),
69-
],
70-
)
71-
]
72-
dummy_metadata = mlc.Metadata(
73-
name="DummyDataset",
74-
description="Dummy description.",
75-
cite_as=(
76-
"@article{dummyarticle, title={title}, author={author}, year={2020}}"
77-
),
78-
url="https://dummy_url",
79-
distribution=distribution,
80-
record_sets=record_sets,
81-
version="1.2.0",
82-
license="Public",
83-
)
84-
return dummy_metadata
85-
8627

8728
@pytest.mark.parametrize(
8829
["field", "feature_type", "int_dtype", "float_dtype"],
@@ -169,29 +110,15 @@ class CroissantBuilderTest(testing.TestCase):
169110
def setUpClass(cls):
  """Builds and prepares a dummy Croissant dataset shared by all tests.

  Post-commit version: the hand-rolled JSONL/JSON-LD setup was replaced by
  the `testing.dummy_croissant_file` context manager added in this commit.
  NOTE(review): reconstructed from a diff view with line-number artifacts;
  the runtime tokens match the "+" side of the hunk exactly.
  """
  super(CroissantBuilderTest, cls).setUpClass()
  with testing.dummy_croissant_file() as croissant_file:
    cls._tfds_tmp_dir = testing.make_tmp_dir()
    cls.builder = croissant_builder.CroissantBuilder(
        jsonld=croissant_file,
        file_format=FileFormat.ARRAY_RECORD,
        disable_shuffling=True,
        data_dir=cls._tfds_tmp_dir,
    )
    # Must run inside the context manager: the temporary Croissant file and
    # its raw data are deleted when the context exits.
    cls.builder.download_and_prepare()
195122

196123
def test_dataset_info(self):
197124
assert self.builder.name == "dummydataset"
@@ -211,10 +138,9 @@ def test_dataset_info(self):
211138
)
212139

213140
def test_generated_samples(self):
  """Checks both splits yield exactly the two expected dummy examples.

  Post-commit version: `download_and_prepare()` moved to `setUpClass`, and
  expectations are inlined (the module-level DUMMY_ENTRIES constant was
  removed by this commit).
  """
  for split_name in ["all", "default"]:
    data_source = self.builder.as_data_source(split=split_name)
    assert len(data_source) == 2
    for i in range(2):
      # Entries were written as {"index": i, "text": f"Dummy example {i}"};
      # text round-trips as bytes, hence the decode().
      assert data_source[i]["index"] == i
      assert data_source[i]["text"].decode() == f"Dummy example {i}"

tensorflow_datasets/testing/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from tensorflow_datasets.testing.test_case_in_context import TestCaseInContext
4040
from tensorflow_datasets.testing.test_utils import assert_features_equal
4141
from tensorflow_datasets.testing.test_utils import disable_gcs_access
42+
from tensorflow_datasets.testing.test_utils import dummy_croissant_file
4243
from tensorflow_datasets.testing.test_utils import DummyBeamDataset
4344
from tensorflow_datasets.testing.test_utils import DummyDataset
4445
from tensorflow_datasets.testing.test_utils import DummyDatasetCollection
@@ -53,9 +54,9 @@
5354
from tensorflow_datasets.testing.test_utils import MockFs
5455
from tensorflow_datasets.testing.test_utils import rm_tmp_dir
5556
from tensorflow_datasets.testing.test_utils import run_in_graph_and_eager_modes
57+
from tensorflow_datasets.testing.test_utils import set_current_datetime
5658
from tensorflow_datasets.testing.test_utils import test_main
5759
from tensorflow_datasets.testing.test_utils import tmp_dir
58-
from tensorflow_datasets.testing.test_utils import set_current_datetime
5960
# LINT.ThenChange(:deps)
6061
# pylint: enable=g-import-not-at-top,g-importing-member
6162

@@ -66,6 +67,7 @@
6667
"tensorflow_datasets.testing.dataset_builder_testing"
6768
),
6869
"disable_gcs_access": "tensorflow_datasets.testing.test_utils",
70+
"dummy_croissant_file": "tensorflow_datasets.testing.test_utils",
6971
"DummyBeamDataset": "tensorflow_datasets.testing.test_utils",
7072
"DummyDataset": "tensorflow_datasets.testing.test_utils",
7173
"DummyDatasetCollection": "tensorflow_datasets.testing.test_utils",
@@ -90,14 +92,14 @@
9092
# TODO(afrozm): rm from here and add as methods to TestCase
9193
"rm_tmp_dir": "tensorflow_datasets.testing.test_utils",
9294
"run_in_graph_and_eager_modes": "tensorflow_datasets.testing.test_utils",
95+
"set_current_datetime": "tensorflow_datasets.testing.test_utils",
9396
"SubTestCase": "tensorflow_datasets.testing.feature_test_case",
9497
"test_main": "tensorflow_datasets.testing.test_utils",
9598
"TestCase": "tensorflow_datasets.testing.test_case",
9699
"TestCaseInContext": "tensorflow_datasets.testing.test_case_in_context",
97100
"TestValue": "tensorflow_datasets.testing.feature_test_case",
98101
# TODO(afrozm): rm from here and add as methods to TestCase
99102
"tmp_dir": "tensorflow_datasets.testing.test_utils",
100-
"set_current_datetime": "tensorflow_datasets.testing.test_utils",
101103
# LINT.ThenChange(:pydeps)
102104
}
103105

tensorflow_datasets/testing/test_utils.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import dataclasses
2222
import datetime
2323
import functools
24+
import json
2425
import os
2526
import pathlib
2627
import subprocess
@@ -39,6 +40,7 @@
3940
from tensorflow_datasets.core import lazy_imports_lib
4041
from tensorflow_datasets.core import naming
4142
from tensorflow_datasets.core import utils
43+
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
4244
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
4345

4446

@@ -706,3 +708,77 @@ def now(cls, tz=None) -> datetime.datetime:
706708

707709
with mock.patch.object(datetime, 'datetime', new=MockDatetime):
708710
yield
711+
712+
713+
@contextlib.contextmanager
def dummy_croissant_file() -> Iterator[epath.Path]:
  """Yields temporary path to a dummy Croissant file.

  The function creates a temporary directory that stores raw data files and
  the Croissant JSON-LD. Everything is deleted when the context exits, so
  callers must consume the file (e.g. build a dataset) inside the `with`.
  """
  # Two fixed examples; the sha256 below is the digest of the JSONL file
  # these entries produce, so any change here must update it too.
  entries = [{'index': i, 'text': f'Dummy example {i}'} for i in range(2)]
  distribution = [
      mlc.FileObject(
          id='raw_data',
          description='File with the data.',
          encoding_format='application/jsonlines',
          content_url='data/raw_data.jsonl',
          sha256=(
              'b13bbcd65bb5ec7c0c64cbceb635de3eadda17f3311c5982dc2d5a342ed97690'
          ),
      ),
  ]
  record_sets = [
      mlc.RecordSet(
          id='jsonl',
          description='Dummy record set.',
          fields=[
              mlc.Field(
                  name='index',
                  description='The sample index.',
                  data_types=mlc.DataType.INTEGER,
                  source=mlc.Source(
                      file_object='raw_data',
                      extract=mlc.Extract(column='index'),
                  ),
              ),
              mlc.Field(
                  name='text',
                  description='The dummy sample text.',
                  data_types=mlc.DataType.TEXT,
                  source=mlc.Source(
                      file_object='raw_data',
                      extract=mlc.Extract(column='text'),
                  ),
              ),
          ],
      )
  ]
  dummy_metadata = mlc.Metadata(
      name='DummyDataset',
      description='Dummy description.',
      cite_as=(
          '@article{dummyarticle, title={title}, author={author}, year={2020}}'
      ),
      url='https://dummy_url',
      distribution=distribution,
      record_sets=record_sets,
      version='1.2.0',
      license='Public',
  )

  with tempfile.TemporaryDirectory() as tempdir:
    tempdir = epath.Path(tempdir)

    # Write raw examples to tempdir/data, matching the FileObject's
    # relative content_url ("data/raw_data.jsonl").
    raw_data_dir = tempdir / 'data'
    raw_data_dir.mkdir()
    raw_data_file = raw_data_dir / 'raw_data.jsonl'
    raw_data_file.write_text('\n'.join(map(json.dumps, entries)))

    # Write Croissant JSON-LD to tempdir.
    croissant_file = tempdir / 'croissant.json'
    croissant_file.write_text(json.dumps(dummy_metadata.to_json(), indent=2))

    yield croissant_file

tensorflow_datasets/testing/test_utils_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pathlib
1919

20+
import mlcroissant as mlc
2021
import pytest
2122
import tensorflow as tf
2223
from tensorflow_datasets.testing import test_case
@@ -226,3 +227,21 @@ def is_lambda(fn):
226227
assert not is_lambda(gcs_utils.gcs_dataset_info_files)
227228
assert is_lambda(gcs_utils.gcs_dataset_info_files)
228229
assert not is_lambda(gcs_utils.gcs_dataset_info_files)
230+
231+
232+
def test_dummy_croissant_file():
  """Smoke-tests the dummy Croissant file via the mlcroissant loader."""
  # NOTE(review): the scraped diff drops indentation; assertions are placed
  # inside the `with` because the temporary file (and its raw data) is
  # deleted on exit, so `dataset.records(...)` must run before then.
  with test_utils.dummy_croissant_file() as croissant_file:
    dataset = mlc.Dataset(jsonld=croissant_file)

    assert dataset.jsonld == croissant_file
    assert dataset.mapping is None
    assert dataset.metadata.description == 'Dummy description.'
    assert [record_set.id for record_set in dataset.metadata.record_sets] == [
        'jsonl'
    ]
    assert [record for record in dataset.records('jsonl')] == [
        {'text': b'Dummy example 0', 'index': 0},
        {'text': b'Dummy example 1', 'index': 1},
    ]
    assert dataset.metadata.url == 'https://dummy_url'
    assert dataset.metadata.version == '1.2.0'

0 commit comments

Comments
 (0)