
Commit 4627fce

Author: The TensorFlow Datasets Authors (committed)
Refactor utils to support both multilingual descriptions and names in CroissantBuilder.
PiperOrigin-RevId: 799927689
1 parent b13d12b · commit 4627fce

4 files changed: +175 -49 lines changed

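For context: Croissant metadata allows `name` and `description` to be either plain strings or dictionaries keyed by language code. A minimal sketch of the two shapes the refactored utilities now accept (the field values are illustrative only, mirroring the constructs used in the new tests below):

import mlcroissant as mlc

# Plain-string description: the only shape handled before this change.
plain_field = mlc.Field(
    data_types=mlc.DataType.TEXT,
    description="A plain-text description",
)

# Localized description: a dict mapping language codes to strings. The
# refactored helpers resolve this to a single string, preferring 'en' and
# otherwise falling back to the first available language.
localized_field = mlc.Field(
    data_types=mlc.DataType.TEXT,
    description={"en": "English description", "fr": "Description française"},
)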

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 17 additions & 10 deletions
@@ -107,17 +107,21 @@ def array_datatype_converter(
   elif enp.lazy.is_np_dtype(field.data_type):
     field_dtype = field.data_type
 
+  description = croissant_utils.extract_localized_string(
+      field.description, field_name='description'
+  )
+
   if len(field.array_shape_tuple) == 1:
-    return sequence_feature.Sequence(feature, doc=field.description)
+    return sequence_feature.Sequence(feature, doc=description)
   elif (-1 in field.array_shape_tuple) or (field_dtype is None):
     for _ in range(len(field.array_shape_tuple)):
-      feature = sequence_feature.Sequence(feature, doc=field.description)
+      feature = sequence_feature.Sequence(feature, doc=description)
     return feature
   else:
     return tensor_feature.Tensor(
         shape=field.array_shape_tuple,
         dtype=field_dtype,
-        doc=field.description,
+        doc=description,
     )
 
 
@@ -151,6 +155,9 @@ def datatype_converter(
   }
 
   field_data_type = field.data_type
+  description = croissant_utils.extract_localized_string(
+      field.description, field_name='description'
+  )
 
   if not field_data_type:
     # Fields with sub fields are of type None.
@@ -162,22 +169,22 @@ def datatype_converter(
               )
              for subfield in field.sub_fields
          },
-          doc=field.description,
+          doc=description,
      )
    else:
      feature = None
  elif field_data_type == bytes:
-    feature = text_feature.Text(doc=field.description)
+    feature = text_feature.Text(doc=description)
  elif field_data_type in dtype_mapping:
    feature = dtype_mapping[field_data_type]
  elif enp.lazy.is_np_dtype(field_data_type):
    feature = field_data_type
  # We return a text feature for date-time features (mlc.DataType.DATE,
  # mlc.DataType.DATETIME, and mlc.DataType.TIME).
  elif field_data_type == pd.Timestamp or field_data_type == datetime.time:
-    feature = text_feature.Text(doc=field.description)
+    feature = text_feature.Text(doc=description)
  elif field_data_type == mlc.DataType.IMAGE_OBJECT:
-    feature = image_feature.Image(doc=field.description)
+    feature = image_feature.Image(doc=description)
  elif field_data_type == mlc.DataType.BOUNDING_BOX:
    # TFDS uses REL_YXYX by default, but Hugging Face doesn't enforce a format.
    if bbox_format := field.source.format:
@@ -190,14 +197,14 @@ def datatype_converter(
            f'{[format.value for format in bb_utils.BBoxFormat]}'
        ) from e
    feature = bounding_boxes.BBoxFeature(
-        doc=field.description, bbox_format=bbox_format
+        doc=description, bbox_format=bbox_format
    )
  elif field_data_type == mlc.DataType.AUDIO_OBJECT:
    feature = audio_feature.Audio(
-        doc=field.description, sample_rate=field.source.sampling_rate
+        doc=description, sample_rate=field.source.sampling_rate
    )
  elif field_data_type == mlc.DataType.VIDEO_OBJECT:
-    feature = video_feature.Video(doc=field.description)
+    feature = video_feature.Video(doc=description)
  else:
    raise ValueError(
        f'Unknown data type: {field_data_type} for field {field.id}.'
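The net effect of the hunks above: every feature's `doc` is now derived once via `croissant_utils.extract_localized_string` instead of passing `field.description` through verbatim. A simplified sketch of the fallback order the helper applies when no language is requested (not the actual implementation, which is added in croissant_utils.py below):

def _resolve_description_sketch(description):
  """Simplified fallback order for a possibly localized description."""
  if description is None or isinstance(description, str):
    return description  # Plain strings (and None) pass through unchanged.
  # Localized dict: prefer 'en', otherwise fall back to the first entry.
  return description.get("en", next(iter(description.values())))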

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 25 additions & 1 deletion
@@ -262,7 +262,12 @@ def test_datatype_converter_complex(
     subfield_types: Dict[str, Type[Any]] | None,
 ):
   actual_feature = croissant_builder.datatype_converter(mlc_field)
-  assert actual_feature.doc.desc == mlc_field.description
+  expected_description = mlc_field.description
+  if isinstance(expected_description, dict):
+    expected_description = expected_description.get(
+        "en", next(iter(expected_description.values()))
+    )
+  assert actual_feature.doc.desc == expected_description
   assert isinstance(actual_feature, feature_type)
   if subfield_types is not None:
     for feature_name in actual_feature.keys():
@@ -271,6 +276,25 @@ def test_datatype_converter_complex(
       )
 
 
+def test_datatype_converter_multilingual_description():
+  mlc_field = mlc.Field(
+      data_types=mlc.DataType.TEXT,
+      description={"en": "English desc", "fr": "Description française"},
+  )
+  actual_feature = croissant_builder.datatype_converter(mlc_field)
+  assert actual_feature.doc.desc == "English desc"
+
+  mlc_field_no_en = mlc.Field(
+      data_types=mlc.DataType.TEXT,
+      description={
+          "de": "Deutsche Beschreibung",
+          "fr": "Description française",
+      },
+  )
+  actual_feature_no_en = croissant_builder.datatype_converter(mlc_field_no_en)
+  assert actual_feature_no_en.doc.desc == "Deutsche Beschreibung"
+
+
 def test_datatype_converter_none():
   field = mlc.Field(
       name="my_field", id="my_field", description="Field with empty data type."

tensorflow_datasets/core/utils/croissant_utils.py

Lines changed: 67 additions & 20 deletions
@@ -63,6 +63,65 @@ def get_croissant_version(version: str | None) -> str | None:
   return version
 
 
+def extract_localized_string(
+    attribute: str | dict[str, str] | None,
+    language: str | None = None,
+    field_name: str = "text field",
+) -> str | None:
+  """Returns the text in the specified language from a potentially localized object.
+
+  Some attributes in Croissant (e.g., `name` and `description`) can be
+  localized, meaning that they can be either simple strings, or dictionaries
+  mapping language codes to strings (e.g., `{"en": "English Name", "fr": "Nom
+  français"}`). This function extracts the text in the specified language from a
+  potentially localized object.
+
+  Args:
+    attribute: The object containing the text, which can be a simple string, a
+      dictionary mapping language codes to strings, or None.
+    language: The desired language code. If None, a heuristic is used: 'en' is
+      preferred, otherwise the first available language in the dictionary.
+    field_name: The name of the field being processed (e.g., "name",
+      "description"), used for error messages.
+
+  Returns:
+    The text string in the desired language, or None if the input is None.
+
+  Raises:
+    ValueError: If the text_object is an empty dictionary, or if the specified
+      language is not found.
+    TypeError: If attribute is not a str, dict, or None.
+  """
+  if attribute is None:
+    return None
+  if isinstance(attribute, str):
+    return attribute
+
+  if not isinstance(attribute, dict):
+    raise TypeError(
+        f"{field_name} must be a string, dictionary, or None. Got"
+        f" {type(attribute)}"
+    )
+
+  if language is None:
+    # Try a heuristic language, e.g., 'en'.
+    if "en" in attribute:
+      return attribute["en"]
+    # Otherwise, take the first language in the dict.
+    try:
+      first_lang = next(iter(attribute))
+      return attribute[first_lang]
+    except StopIteration as exc:
+      raise ValueError(f"Dataset `{field_name}` dictionary is empty.") from exc
+  elif language in attribute:
+    return attribute[language]
+  else:
+    raise ValueError(
+        f"Language '{language}' not found in {field_name} keys:"
+        f" {list(attribute.keys())}."
+    )
+
+
 def get_dataset_name(dataset: mlc.Dataset, language: str | None = None) -> str:
   """Returns dataset name of the given MLcroissant dataset.
 
@@ -73,26 +132,14 @@ def get_dataset_name(dataset: mlc.Dataset, language: str | None = None) -> str:
   """
   if (url := dataset.metadata.url) and url.startswith(_HUGGINGFACE_URL_PREFIX):
     return url.removeprefix(_HUGGINGFACE_URL_PREFIX)
-  name = dataset.metadata.name
-  if isinstance(name, dict):
-    if language is None:
-      # Try a heuristic language, e.g., 'en'.
-      if "en" in name:
-        return name["en"]
-      # Otherwise, take the first language in the dict.
-      try:
-        first_lang = next(iter(name))
-        return name[first_lang]
-      except StopIteration as exc:
-        raise ValueError("Dataset name dictionary is empty.") from exc
-    elif language not in dataset.metadata.name:
-      raise ValueError(
-          f"Language {language} not found in dataset names {name}."
-      )
-    else:
-      return name[language]
-  # At this point, name is not a dict anymore.
-  return typing.cast(str, name)
+  name = extract_localized_string(
+      dataset.metadata.name, language=language, field_name="name"
+  )
+  if name is None:
+    # This case should ideally be prevented by mlcroissant's validation
+    # ensuring metadata.name is not None.
+    raise ValueError("Dataset name is missing.")
+  return name
 
 
 def get_tfds_dataset_name(
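For reference, a few illustrative calls to the new helper with results inferred from the function body above; the parametrized test in croissant_utils_test.py below exercises the same cases, and the import path shown here is assumed to match the test module:

from tensorflow_datasets.core.utils import croissant_utils

# Plain strings and None pass through unchanged.
croissant_utils.extract_localized_string('Simple Text')  # -> 'Simple Text'
croissant_utils.extract_localized_string(None)  # -> None

# Localized dict, no language requested: 'en' is preferred when present...
croissant_utils.extract_localized_string(
    {'en': 'English Text', 'fr': 'Texte Français'}
)  # -> 'English Text'

# ...otherwise the first language in the dict is used.
croissant_utils.extract_localized_string(
    {'de': 'Deutscher Text', 'fr': 'Texte Français'}
)  # -> 'Deutscher Text'

# Requesting a language that is missing raises ValueError.
croissant_utils.extract_localized_string(
    {'en': 'English Text'}, language='de'
)  # raises ValueError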

tensorflow_datasets/core/utils/croissant_utils_test.py

Lines changed: 66 additions & 18 deletions
@@ -38,6 +38,53 @@ def test_get_tfds_dataset_name(croissant_name, croissant_url, tfds_name):
   ), f'Expected TFDS name: {tfds_name}'
 
 
+@pytest.mark.parametrize(
+    'attribute,language,expected_text',
+    [
+        ({'en': 'English Text', 'fr': 'Texte Français'}, None, 'English Text'),
+        (
+            {'de': 'Deutscher Text', 'fr': 'Texte Français'},
+            None,
+            'Deutscher Text',
+        ),
+        (
+            {'en': 'English Text', 'fr': 'Texte Français'},
+            'fr',
+            'Texte Français',
+        ),
+        ('Simple Text', None, 'Simple Text'),
+        ('Simple Text', 'en', 'Simple Text'),
+        (None, None, None),
+    ],
+)
+def test_extract_localized_string(attribute, language, expected_text):
+  assert (
+      croissant_utils.extract_localized_string(attribute, language=language)
+      == expected_text
+  )
+
+
+def test_extract_localized_string_raises():
+  # Language not found.
+  with pytest.raises(
+      ValueError,
+      match=r"Language 'de' not found in text field keys:",
+  ):
+    croissant_utils.extract_localized_string(
+        {'en': 'English Text', 'fr': 'Texte Français'}, language='de'
+    )
+
+  # Empty dictionary.
+  with pytest.raises(
+      ValueError, match='Dataset `text field` dictionary is empty'
+  ):
+    croissant_utils.extract_localized_string({}, language=None)
+
+  # Incorrect type.
+  with pytest.raises(TypeError, match='must be a string, dictionary, or None'):
+    croissant_utils.extract_localized_string(123)
+
+
 @pytest.mark.parametrize(
     'croissant_name,language,expected_name',
     [
@@ -61,6 +108,25 @@ def test_get_dataset_name(croissant_name, language, expected_name):
   )
 
 
+def test_get_dataset_name_raises():
+  ctx = mlc.Context(conforms_to='http://mlcommons.org/croissant/1.1')
+  # Test language not found in name.
+  metadata_lang_not_found = mlc.Metadata(
+      name={'en': 'English Name', 'fr': 'Nom Français'}, ctx=ctx, url=None
+  )
+  dataset_lang_not_found = mlc.Dataset.from_metadata(metadata_lang_not_found)
+  with pytest.raises(
+      ValueError, match=r"Language 'de' not found in name keys:"
+  ):
+    croissant_utils.get_dataset_name(dataset_lang_not_found, language='de')
+
+  # Test empty dictionary name.
+  metadata_empty_dict = mlc.Metadata(name={}, ctx=ctx, url=None)
+  dataset_empty_dict = mlc.Dataset.from_metadata(metadata_empty_dict)
+  with pytest.raises(ValueError, match='Dataset `name` dictionary is empty.'):
+    croissant_utils.get_dataset_name(dataset_empty_dict, language=None)
+
+
 def test_get_dataset_name_url_precedence():
   ctx = mlc.Context(conforms_to='http://mlcommons.org/croissant/1.1')
   # Test that URL prefix removal works and takes precedence over name.
@@ -94,24 +160,6 @@ def test_get_dataset_name_url_precedence():
   assert croissant_utils.get_dataset_name(dataset_other_url) == 'Not Ignored'
 
 
-def test_get_dataset_multilingual_name_with_language_not_found():
-  ctx = mlc.Context(conforms_to='http://mlcommons.org/croissant/1.1')
-  metadata_lang_not_found = mlc.Metadata(
-      name={'en': 'English Name', 'fr': 'Nom Français'}, ctx=ctx, url=None
-  )
-  dataset_lang_not_found = mlc.Dataset.from_metadata(metadata_lang_not_found)
-  with pytest.raises(ValueError, match='Language de not found'):
-    croissant_utils.get_dataset_name(dataset_lang_not_found, language='de')
-
-
-def test_get_dataset_multilingual_name_with_empty_dict():
-  ctx = mlc.Context(conforms_to='http://mlcommons.org/croissant/1.1')
-  metadata_empty_dict = mlc.Metadata(name={}, ctx=ctx, url=None)
-  dataset_empty_dict = mlc.Dataset.from_metadata(metadata_empty_dict)
-  with pytest.raises(ValueError, match='Dataset name dictionary is empty'):
-    croissant_utils.get_dataset_name(dataset_empty_dict, language=None)
-
-
 @pytest.mark.parametrize(
     'croissant_version,tfds_version',
     [
