Skip to content

Commit 89440d4

Browse files
marcenacp and The TensorFlow Datasets Authors
authored and committed
Use 1 unified function to convert Hugging Face names.
PiperOrigin-RevId: 615827787
1 parent d9dfe5d commit 89440d4

File tree

5 files changed

+34
-53
lines changed

5 files changed

+34
-53
lines changed

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
self._hf_repo_id = hf_repo_id
108108
self._hf_config = hf_config
109109
self.config_kwargs = config_kwargs
110-
tfds_config = huggingface_utils.convert_hf_config_name(hf_config)
110+
tfds_config = huggingface_utils.convert_hf_name(hf_config)
111111
try:
112112
self._hf_builder = hf_datasets.load_dataset_builder(
113113
self._hf_repo_id, self._hf_config, **self.config_kwargs
@@ -128,7 +128,7 @@ def __init__(
128128
)
129129
else:
130130
self._converted_builder_config = None
131-
self.name = huggingface_utils.convert_hf_dataset_name(hf_repo_id)
131+
self.name = huggingface_utils.convert_hf_name(hf_repo_id)
132132
self._hf_hub_token = hf_hub_token
133133
self._hf_num_proc = hf_num_proc
134134
self._tfds_num_proc = tfds_num_proc
@@ -189,7 +189,8 @@ def _split_generators(
189189
self._hf_download_and_prepare()
190190
ds = self._hf_builder.as_dataset(verification_mode=self._verification_mode)
191191
splits = {
192-
split: self._generate_examples(data) for split, data in ds.items()
192+
huggingface_utils.convert_hf_name(split): self._generate_examples(data)
193+
for split, data in ds.items()
193194
}
194195
return _remove_empty_splits(splits)
195196

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _info(self):
4545
)
4646

4747
def _split_generators(self, dl_manager):
48-
return [hf_datasets.SplitGenerator(name=hf_datasets.Split.TRAIN)]
48+
return [hf_datasets.SplitGenerator(name='train.clean')]
4949

5050
def _generate_examples(self):
5151
for i in range(2):
@@ -101,7 +101,8 @@ def mock_huggingface_dataset_builder(
101101
def test_download_and_prepare(builder):
102102
builder.download_and_prepare()
103103
ds = builder.as_data_source()
104-
assert list(ds['train']) == [{'feature': 0}, {'feature': 1}]
104+
# Split names are sanitized, e.g. train.clean -> train_clean
105+
assert list(ds['train_clean']) == [{'feature': 0}, {'feature': 1}]
105106

106107

107108
def test_all_parameters_are_passed_down_to_hf(builder):

tensorflow_datasets/core/utils/huggingface_utils.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
from collections.abc import Mapping, Sequence
1919
import datetime
20-
from typing import Any, Type
20+
import re
21+
from typing import Any, Type, TypeVar
2122

2223
from etils import epath
2324
import immutabledict
@@ -39,6 +40,9 @@
3940
'string': np.object_,
4041
})
4142
_IMAGE_ENCODING_FORMAT = 'png'
43+
# Regular expression matching any character that is not valid in a Python/TFDS name:
44+
_INVALID_TFDS_NAME_CHARACTER = re.compile(r'[^a-zA-Z0-9_]')
45+
_StrOrNone = TypeVar('_StrOrNone', str, None)
4246

4347

4448
def _convert_to_np_dtype(hf_dtype: str) -> Type[np.generic]:
@@ -229,29 +233,27 @@ def convert_hf_value(
229233
)
230234

231235

232-
def convert_hf_dataset_name(hf_dataset_name: str) -> str:
233-
"""Converts Huggingface dataset name to a TFDS compatible dataset name.
236+
def convert_hf_name(hf_name: _StrOrNone) -> _StrOrNone:
237+
"""Converts Huggingface name to a TFDS compatible dataset name.
234238
235-
Huggingface dataset names can contain characters that are not supported in
239+
Huggingface names can contain characters that are not supported in
236240
TFDS. For example, in Huggingface a dataset name like `a/b` is supported,
237241
while in TFDS `b` would be parsed as the config.
238242
239243
Examples:
240-
- `hf_dataset_name='codeparrot/github-code'` becomes
241-
`codeparrot__github_code`.
244+
- `hf_name='codeparrot/github-code'` becomes `codeparrot__github_code`.
242245
243246
Args:
244-
hf_dataset_name: Huggingface dataset name.
247+
hf_name: Huggingface name.
245248
246249
Returns:
247-
The TFDS compatible dataset name.
250+
The TFDS compatible name (applies to dataset names, config names and split
251+
names).
248252
"""
249-
return (
250-
hf_dataset_name.replace('-', '_')
251-
.replace('.', '_')
252-
.replace('/', '__')
253-
.lower()
254-
)
253+
if hf_name is None:
254+
return hf_name
255+
hf_name = hf_name.lower().replace('/', '__')
256+
return re.sub(_INVALID_TFDS_NAME_CHARACTER, '_', hf_name)
255257

256258

257259
def convert_tfds_dataset_name(tfds_dataset_name: str) -> str:
@@ -271,22 +273,8 @@ def convert_tfds_dataset_name(tfds_dataset_name: str) -> str:
271273
existing Huggingface dataset.
272274
"""
273275
for hf_dataset_name in hf_datasets.list_datasets():
274-
if convert_hf_dataset_name(hf_dataset_name) == tfds_dataset_name.lower():
276+
if convert_hf_name(hf_dataset_name) == tfds_dataset_name.lower():
275277
return hf_dataset_name
276278
raise registered.DatasetNotFoundError(
277279
f'"{tfds_dataset_name}" is not listed in Huggingface datasets.'
278280
)
279-
280-
281-
def convert_hf_config_name(hf_config_name: str | None) -> str | None:
282-
"""Converts Huggingface config name to a TFDS compatible config name.
283-
284-
Args:
285-
hf_config_name: Optional Huggingface config name.
286-
287-
Returns:
288-
The TFDS compatible config name.
289-
"""
290-
if hf_config_name is None:
291-
return hf_config_name
292-
return hf_config_name.lower().replace(',', '_')

tensorflow_datasets/core/utils/huggingface_utils_test.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -212,20 +212,22 @@ def test_convert_value(hf_value, feature, expected_value):
212212

213213

214214
@pytest.mark.parametrize(
215-
'hf_dataset_name,tfds_dataset_name',
215+
'hf_name,tfds_name',
216216
[
217+
# Dataset names
217218
('x', 'x'),
218219
('X', 'x'),
219220
('x-y', 'x_y'),
220221
('x/y', 'x__y'),
222+
('x/Y-z', 'x__y_z'),
223+
# Config and split names
224+
('x.y', 'x_y'),
221225
('x_v1.0', 'x_v1_0'),
226+
(None, None),
222227
],
223228
)
224-
def test_from_hf_to_tfds(hf_dataset_name, tfds_dataset_name):
225-
assert (
226-
huggingface_utils.convert_hf_dataset_name(hf_dataset_name)
227-
== tfds_dataset_name
228-
)
229+
def test_from_hf_to_tfds(hf_name, tfds_name):
230+
assert huggingface_utils.convert_hf_name(hf_name) == tfds_name
229231

230232

231233
@pytest.fixture(name='mock_list_datasets')
@@ -263,14 +265,3 @@ def test_convert_tfds_dataset_name(
263265
huggingface_utils.convert_tfds_dataset_name(tfds_dataset_name)
264266
== hf_dataset_name
265267
)
266-
267-
268-
@pytest.mark.parametrize(
269-
'hf_config_name,tfds_config_name',
270-
[(None, None), ('x', 'x'), ('X', 'x'), ('X,y', 'x_y')],
271-
)
272-
def test_convert_config_name(hf_config_name, tfds_config_name):
273-
assert (
274-
huggingface_utils.convert_hf_config_name(hf_config_name)
275-
== tfds_config_name
276-
)

tensorflow_datasets/scripts/documentation/build_community_catalog.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,11 @@ def format_template(
255255
config_name: str, info: dataset_info_pb2.DatasetInfo
256256
) -> str:
257257
if self.namespace and self.namespace == 'huggingface':
258-
tfds_id = huggingface_utils.convert_hf_dataset_name(self.tfds_id)
258+
tfds_id = huggingface_utils.convert_hf_name(self.tfds_id)
259259
else:
260260
tfds_id = self.tfds_id
261261
if config_name != 'default':
262-
config_name = huggingface_utils.convert_hf_config_name(config_name)
262+
config_name = huggingface_utils.convert_hf_name(config_name)
263263
tfds_id = f'{tfds_id}/{config_name}'
264264
if keep_short:
265265
features = ''

0 commit comments

Comments
 (0)