Skip to content

Commit 51c555c

Browse files
author
The TensorFlow Datasets Authors
committed
Add tests to CroissantBuilder.
PiperOrigin-RevId: 786184714
1 parent cc46f9c commit 51c555c

File tree

2 files changed

+118
-26
lines changed

2 files changed

+118
-26
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def datatype_converter(
214214
return feature
215215

216216

217-
def _extract_license(license_: Any) -> str | None:
217+
def _extract_license(license_: Any) -> str:
218218
"""Extracts the full terms of a license as a string.
219219
220220
In case the license is a CreativeWork, we join the name, description and url
@@ -234,12 +234,13 @@ def _extract_license(license_: Any) -> str | None:
234234
fields = [field for field in possible_fields if field]
235235
return '[' + ']['.join(fields) + ']'
236236
raise ValueError(
237-
f'license_ should be mlc.CreativeWork | str. Got {type(license_)}'
237+
'license_ should be mlc.CreativeWork | str. Got'
238+
f' {type(license_)}: {license_}.'
238239
)
239240

240241

241242
def _get_license(metadata: Any) -> str | None:
242-
"""Gets the license from the metadata."""
243+
"""Gets the license from the metadata (if any) else returns None."""
243244
if not isinstance(metadata, mlc.Metadata):
244245
raise ValueError(f'metadata should be mlc.Metadata. Got {type(metadata)}')
245246
licenses = metadata.license

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 114 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
FileFormat = file_adapters.FileFormat
3434

35+
DUMMY_DESCRIPTION = "Dummy description."
36+
3537

3638
DUMMY_ENTRIES = [
3739
{
@@ -51,8 +53,30 @@
5153
]
5254

5355

56+
def _create_mlc_field(
57+
data_types: mlc.DataType | list[mlc.DataType],
58+
description: str,
59+
is_array: bool = False,
60+
array_shape: str | None = None,
61+
repeated: bool = False,
62+
source: mlc.Source | None = None,
63+
sub_fields: list[mlc.Field] | None = None,
64+
) -> mlc.Field:
65+
field = mlc.Field(
66+
data_types=data_types,
67+
description=description,
68+
is_array=is_array,
69+
array_shape=array_shape,
70+
repeated=repeated,
71+
sub_fields=sub_fields,
72+
)
73+
if source is not None:
74+
field.source = source
75+
return field
76+
77+
5478
@pytest.mark.parametrize(
55-
["field", "expected_feature", "int_dtype", "float_dtype"],
79+
["mlc_field", "expected_feature", "int_dtype", "float_dtype"],
5680
[
5781
(
5882
mlc.Field(
@@ -121,18 +145,18 @@
121145
],
122146
)
123147
def test_simple_datatype_converter(
124-
field, expected_feature, int_dtype, float_dtype
148+
mlc_field, expected_feature, int_dtype, float_dtype
125149
):
126150
actual_feature = croissant_builder.datatype_converter(
127-
field,
151+
mlc_field,
128152
int_dtype=int_dtype or np.int64,
129153
float_dtype=float_dtype or np.float32,
130154
)
131155
assert actual_feature == expected_feature
132156

133157

134-
def test_bbox_datatype_converter():
135-
field = mlc.Field(
158+
def test_datatype_converter_bbox():
159+
field = _create_mlc_field(
136160
data_types=mlc.DataType.BOUNDING_BOX,
137161
description="Bounding box feature",
138162
source=mlc.Source(format="XYWH"),
@@ -142,8 +166,8 @@ def test_bbox_datatype_converter():
142166
assert actual_feature.bbox_format == bb_utils.BBoxFormat.XYWH
143167

144168

145-
def test_bbox_datatype_converter_with_invalid_format():
146-
field = mlc.Field(
169+
def test_datatype_converter_bbox_with_invalid_format():
170+
field = _create_mlc_field(
147171
data_types=mlc.DataType.BOUNDING_BOX,
148172
description="Bounding box feature",
149173
source=mlc.Source(format="InvalidFormat"),
@@ -153,7 +177,7 @@ def test_bbox_datatype_converter_with_invalid_format():
153177

154178

155179
@pytest.mark.parametrize(
156-
["field", "feature_type", "subfield_types"],
180+
["mlc_field", "feature_type", "subfield_types"],
157181
[
158182
(
159183
mlc.Field(data_types=mlc.DataType.TEXT, description="Text feature"),
@@ -219,11 +243,11 @@ def test_bbox_datatype_converter_with_invalid_format():
219243
),
220244
],
221245
)
222-
def test_complex_datatype_converter(field, feature_type, subfield_types):
223-
actual_feature = croissant_builder.datatype_converter(field)
224-
assert actual_feature.doc.desc == field.description
246+
def test_datatype_converter_complex(mlc_field, feature_type, subfield_types):
247+
actual_feature = croissant_builder.datatype_converter(mlc_field)
248+
assert actual_feature.doc.desc == mlc_field.description
225249
assert isinstance(actual_feature, feature_type)
226-
if subfield_types:
250+
if subfield_types is not None:
227251
for feature_name in actual_feature.keys():
228252
assert isinstance(
229253
actual_feature[feature_name], subfield_types[feature_name]
@@ -238,67 +262,134 @@ def test_datatype_converter_none():
238262

239263

240264
def test_multidimensional_datatype_converter():
241-
field = mlc.Field(
265+
mlc_field = _create_mlc_field(
242266
data_types=mlc.DataType.TEXT,
243267
description="Text feature",
244268
is_array=True,
245269
array_shape="2,2",
246270
)
247-
actual_feature = croissant_builder.datatype_converter(field)
271+
actual_feature = croissant_builder.datatype_converter(mlc_field)
248272
assert isinstance(actual_feature, tensor_feature.Tensor)
249273
assert actual_feature.shape == (2, 2)
250274
assert actual_feature.dtype == np.str_
251275

252276

253277
def test_multidimensional_datatype_converter_image_object():
254-
field = mlc.Field(
278+
mlc_field = _create_mlc_field(
255279
data_types=mlc.DataType.IMAGE_OBJECT,
256280
description="Text feature",
257281
is_array=True,
258282
array_shape="2,2",
259283
)
260-
actual_feature = croissant_builder.datatype_converter(field)
284+
actual_feature = croissant_builder.datatype_converter(mlc_field)
261285
assert isinstance(actual_feature, sequence_feature.Sequence)
262286
assert isinstance(actual_feature.feature, sequence_feature.Sequence)
263287
assert isinstance(actual_feature.feature.feature, image_feature.Image)
264288

265289

266290
def test_multidimensional_datatype_converter_plain_list():
267-
field = mlc.Field(
291+
mlc_field = _create_mlc_field(
268292
data_types=mlc.DataType.TEXT,
269293
description="Text feature",
270294
is_array=True,
271295
array_shape="-1",
272296
)
273-
actual_feature = croissant_builder.datatype_converter(field)
297+
actual_feature = croissant_builder.datatype_converter(mlc_field)
274298
assert isinstance(actual_feature, sequence_feature.Sequence)
275299
assert isinstance(actual_feature.feature, text_feature.Text)
276300

277301

278302
def test_multidimensional_datatype_converter_unknown_shape():
279-
field = mlc.Field(
303+
mlc_field = _create_mlc_field(
280304
data_types=mlc.DataType.TEXT,
281305
description="Text feature",
282306
is_array=True,
283307
array_shape="-1,2",
284308
)
285-
actual_feature = croissant_builder.datatype_converter(field)
309+
actual_feature = croissant_builder.datatype_converter(mlc_field)
286310
assert isinstance(actual_feature, sequence_feature.Sequence)
287311
assert isinstance(actual_feature.feature, sequence_feature.Sequence)
288312
assert isinstance(actual_feature.feature.feature, text_feature.Text)
289313

290314

291315
def test_sequence_feature_datatype_converter():
292-
field = mlc.Field(
316+
mlc_field = _create_mlc_field(
293317
data_types=mlc.DataType.TEXT,
294318
description="Text feature",
295319
repeated=True,
296320
)
297-
actual_feature = croissant_builder.datatype_converter(field)
321+
actual_feature = croissant_builder.datatype_converter(mlc_field)
298322
assert isinstance(actual_feature, sequence_feature.Sequence)
299323
assert isinstance(actual_feature.feature, text_feature.Text)
300324

301325

326+
@pytest.mark.parametrize(
327+
["license_", "expected_license"],
328+
[
329+
("MIT", "MIT"),
330+
(
331+
mlc.CreativeWork(
332+
name="Creative Commons",
333+
description="Attribution 4.0 International",
334+
url="https://creativecommons.org/licenses/by/4.0/",
335+
),
336+
(
337+
"[Creative Commons][Attribution 4.0"
338+
" International][https://creativecommons.org/licenses/by/4.0/]"
339+
),
340+
),
341+
(
342+
mlc.CreativeWork(
343+
name="Creative Commons",
344+
),
345+
"[Creative Commons]",
346+
),
347+
(
348+
mlc.CreativeWork(
349+
description="Attribution 4.0 International",
350+
),
351+
"[Attribution 4.0 International]",
352+
),
353+
(
354+
mlc.CreativeWork(
355+
url="https://creativecommons.org/licenses/by/4.0/",
356+
),
357+
"[https://creativecommons.org/licenses/by/4.0/]",
358+
),
359+
(
360+
mlc.CreativeWork(),
361+
"[]",
362+
),
363+
],
364+
)
365+
def test_extract_license(license_, expected_license):
366+
actual_license = croissant_builder._extract_license(license_)
367+
assert actual_license == expected_license
368+
369+
370+
def test_extract_license_with_invalid_input():
371+
with pytest.raises(
372+
ValueError, match="^license_ should be mlc.CreativeWork | str"
373+
):
374+
croissant_builder._extract_license(123)
375+
376+
377+
def test_get_license():
378+
metadata = mlc.Metadata(license=["MIT", "Apache 2.0"])
379+
actual_license = croissant_builder._get_license(metadata)
380+
assert actual_license == "MIT, Apache 2.0"
381+
382+
383+
def test_get_license_with_invalid_input():
384+
with pytest.raises(ValueError, match="metadata should be mlc.Metadata"):
385+
croissant_builder._get_license(123)
386+
387+
388+
def test_get_license_with_empty_license():
389+
metadata = mlc.Metadata(license=[])
390+
assert croissant_builder._get_license(metadata) is None
391+
392+
302393
def test_version_converter(tmp_path):
303394
with testing.dummy_croissant_file(version="1.0") as croissant_file:
304395
builder = croissant_builder.CroissantBuilder(
@@ -344,7 +435,7 @@ def test_croissant_builder(crs_builder):
344435
crs_builder._info().citation
345436
== "@article{dummyarticle, title={title}, author={author}, year={2020}}"
346437
)
347-
assert crs_builder._info().description == "Dummy description."
438+
assert crs_builder._info().description == DUMMY_DESCRIPTION
348439
assert crs_builder._info().homepage == "https://dummy_url"
349440
assert crs_builder._info().redistribution_info.license == "Public"
350441
# One `split` and one `jsonl` recordset.

0 commit comments

Comments
 (0)