25 | 25 | from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
26 | 26 |
27 | 27 |
   | 28 | +DUMMY_ENTRIES = [
   | 29 | +    {"index": i, "text": f"Dummy example {i}"} for i in range(2)
   | 30 | +]
   | 31 | +DUMMY_ENTRIES_WITH_NONE_VALUES = [
   | 32 | +    {"index": 0, "text": "Dummy example 0"},
   | 33 | +    {"index": 1, "text": None},
   | 34 | +]
   | 35 | +DUMMY_ENTRIES_WITH_CONVERTED_NONE_VALUES = [
   | 36 | +    {"index": 0, "text": "Dummy example 0"},
   | 37 | +    {"index": 1, "text": ""},
   | 38 | +]
   | 39 | +
   | 40 | +
28 | 41 | @pytest.mark.parametrize(
29 | 42 |     ["field", "feature_type", "int_dtype", "float_dtype"],
30 | 43 |     [
@@ -104,43 +117,72 @@ def test_complex_datatype_converter(field, feature_type):
104 | 117 |   assert isinstance(actual_feature, feature_type)
105 | 118 |
106 | 119 |
107 |     | -class CroissantBuilderTest(testing.TestCase):
    | 120 | +@pytest.fixture(name="crs_builder")
    | 121 | +def mock_croissant_dataset_builder(tmp_path, request):
    | 122 | +  dataset_name = request.param["dataset_name"]
    | 123 | +  with testing.dummy_croissant_file(
    | 124 | +      dataset_name=dataset_name,
    | 125 | +      entries=request.param["entries"],
    | 126 | +  ) as croissant_file:
    | 127 | +    builder = croissant_builder.CroissantBuilder(
    | 128 | +        jsonld=croissant_file,
    | 129 | +        file_format=FileFormat.ARRAY_RECORD,
    | 130 | +        disable_shuffling=True,
    | 131 | +        data_dir=tmp_path,
    | 132 | +    )
    | 133 | +    yield builder
108 | 134 |
109 |     | -  @classmethod
110 |     | -  def setUpClass(cls):
111 |     | -    super(CroissantBuilderTest, cls).setUpClass()
112 | 135 |
113 |     | -    with testing.dummy_croissant_file() as croissant_file:
114 |     | -      cls._tfds_tmp_dir = testing.make_tmp_dir()
115 |     | -      cls.builder = croissant_builder.CroissantBuilder(
116 |     | -          jsonld=croissant_file,
117 |     | -          file_format=FileFormat.ARRAY_RECORD,
118 |     | -          disable_shuffling=True,
119 |     | -          data_dir=cls._tfds_tmp_dir,
120 |     | -      )
121 |     | -      cls.builder.download_and_prepare()
    | 136 | +@pytest.mark.parametrize(
    | 137 | +    "crs_builder",
    | 138 | +    [
    | 139 | +        {"dataset_name": "DummyDataset", "entries": DUMMY_ENTRIES},
    | 140 | +        {
    | 141 | +            "dataset_name": "DummyDatasetWithNoneValues",
    | 142 | +            "entries": DUMMY_ENTRIES_WITH_NONE_VALUES,
    | 143 | +        },
    | 144 | +    ],
    | 145 | +    indirect=True,
    | 146 | +)
    | 147 | +def test_croissant_builder(crs_builder):
    | 148 | +  assert crs_builder.version == "1.2.0"
    | 149 | +  assert (
    | 150 | +      crs_builder._info().citation
    | 151 | +      == "@article{dummyarticle, title={title}, author={author}, year={2020}}"
    | 152 | +  )
    | 153 | +  assert crs_builder._info().description == "Dummy description."
    | 154 | +  assert crs_builder._info().homepage == "https://dummy_url"
    | 155 | +  assert crs_builder._info().redistribution_info.license == "Public"
    | 156 | +  assert len(crs_builder.metadata.record_sets) == 1
    | 157 | +  assert crs_builder.metadata.record_sets[0].id == "jsonl"
    | 158 | +  assert (
    | 159 | +      crs_builder.metadata.ctx.conforms_to.value
    | 160 | +      == "http://mlcommons.org/croissant/1.0"
    | 161 | +  )
122 | 162 |
123 |     | -  def test_dataset_info(self):
124 |     | -    assert self.builder.name == "dummydataset"
125 |     | -    assert self.builder.version == "1.2.0"
126 |     | -    assert (
127 |     | -        self.builder._info().citation
128 |     | -        == "@article{dummyarticle, title={title}, author={author}, year={2020}}"
129 |     | -    )
130 |     | -    assert self.builder._info().description == "Dummy description."
131 |     | -    assert self.builder._info().homepage == "https://dummy_url"
132 |     | -    assert self.builder._info().redistribution_info.license == "Public"
133 |     | -    assert len(self.builder.metadata.record_sets) == 1
134 |     | -    assert self.builder.metadata.record_sets[0].id == "jsonl"
135 |     | -    assert (
136 |     | -        self.builder.metadata.ctx.conforms_to.value
137 |     | -        == "http://mlcommons.org/croissant/1.0"
138 |     | -    )
139 | 163 |
140 |     | -  def test_generated_samples(self):
141 |     | -    for split_name in ["all", "default"]:
142 |     | -      data_source = self.builder.as_data_source(split=split_name)
143 |     | -      assert len(data_source) == 2
144 |     | -      for i in range(2):
145 |     | -        assert data_source[i]["index"] == i
146 |     | -        assert data_source[i]["text"].decode() == f"Dummy example {i}"
    | 164 | +@pytest.mark.parametrize(
    | 165 | +    "crs_builder,expected_entries",
    | 166 | +    [
    | 167 | +        (
    | 168 | +            {"dataset_name": "DummyDataset", "entries": DUMMY_ENTRIES},
    | 169 | +            DUMMY_ENTRIES,
    | 170 | +        ),
    | 171 | +        (
    | 172 | +            {
    | 173 | +                "dataset_name": "DummyDatasetWithNoneValues",
    | 174 | +                "entries": DUMMY_ENTRIES_WITH_NONE_VALUES,
    | 175 | +            },
    | 176 | +            DUMMY_ENTRIES_WITH_CONVERTED_NONE_VALUES,
    | 177 | +        ),
    | 178 | +    ],
    | 179 | +    indirect=["crs_builder"],
    | 180 | +)
    | 181 | +@pytest.mark.parametrize("split_name", ["all", "default"])
    | 182 | +def test_download_and_prepare(crs_builder, expected_entries, split_name):
    | 183 | +  crs_builder.download_and_prepare()
    | 184 | +  data_source = crs_builder.as_data_source(split=split_name)
    | 185 | +  assert len(data_source) == 2
    | 186 | +  for i in range(2):
    | 187 | +    assert data_source[i]["index"] == expected_entries[i]["index"]
    | 188 | +    assert data_source[i]["text"].decode() == expected_entries[i]["text"]
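
The new tests rely on pytest's `indirect` parametrization: each parameter dict is delivered to the `crs_builder` fixture through `request.param` rather than to the test function directly. For reference, a minimal, self-contained sketch of that pattern follows; the fixture and test names are illustrative only and not part of this change.

    import pytest

    @pytest.fixture(name="builder")
    def make_builder(request):
      # With indirect=True, the value from parametrize() arrives here as request.param.
      return {"name": request.param["dataset_name"], "rows": request.param["entries"]}

    @pytest.mark.parametrize(
        "builder",
        [{"dataset_name": "DummyDataset", "entries": [{"index": 0}]}],
        indirect=True,  # route the dict to the fixture, not to the test argument
    )
    def test_builder(builder):
      assert builder["name"] == "DummyDataset"
      assert len(builder["rows"]) == 1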