Skip to content

Commit cb012e6

Browse files
author
The TensorFlow Datasets Authors
committed
Refactor conversion functions from huggingface_utils to conversion_utils.
PiperOrigin-RevId: 653322590
1 parent 6d37846 commit cb012e6

File tree

8 files changed

+359
-317
lines changed

8 files changed

+359
-317
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949
from tensorflow_datasets.core.features import image_feature
5050
from tensorflow_datasets.core.features import sequence_feature
5151
from tensorflow_datasets.core.features import text_feature
52+
from tensorflow_datasets.core.utils import conversion_utils
5253
from tensorflow_datasets.core.utils import croissant_utils
53-
from tensorflow_datasets.core.utils import huggingface_utils
5454
from tensorflow_datasets.core.utils import type_utils
5555
from tensorflow_datasets.core.utils import version as version_utils
5656
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
@@ -194,7 +194,7 @@ def __init__(
194194
if not record_set_ids:
195195
record_set_ids = croissant_utils.get_record_set_ids(self.metadata)
196196
config_names = [
197-
huggingface_utils.convert_hf_name(record_set)
197+
conversion_utils.to_tfds_name(record_set)
198198
for record_set in record_set_ids
199199
]
200200
self.BUILDER_CONFIGS: Sequence[dataset_builder.BuilderConfig] = [ # pylint: disable=invalid-name
@@ -290,7 +290,7 @@ def _generate_examples(
290290
# Some samples might not be TFDS-compatible as-is, e.g. from croissant
291291
# describing HuggingFace datasets, so we convert them here. This shouldn't
292292
# impact datasets which are already TFDS-compatible.
293-
record = huggingface_utils.convert_hf_value(record, self.info.features)
293+
record = conversion_utils.to_tfds_value(record, self.info.features)
294294
# After partition implementation, the filters will be applied from
295295
# mlcroissant `dataset.records` directly.
296296
# `records = records.filter(f == v for f, v in filters.items())``

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from tensorflow_datasets.core import file_adapters
4545
from tensorflow_datasets.core import split_builder as split_builder_lib
4646
from tensorflow_datasets.core import splits as splits_lib
47+
from tensorflow_datasets.core.utils import conversion_utils
4748
from tensorflow_datasets.core.utils import huggingface_utils
4849
from tensorflow_datasets.core.utils import shard_utils
4950
from tensorflow_datasets.core.utils import tqdm_utils
@@ -148,7 +149,7 @@ def get_serialized_examples_iter():
148149
continue
149150
else:
150151
raise
151-
example = huggingface_utils.convert_hf_value(hf_value, features)
152+
example = conversion_utils.to_tfds_value(hf_value, features)
152153
encoded_example = features.encode_example(example)
153154
serialized_example = serializer.serialize_example(encoded_example)
154155
num_bytes += len(serialized_example)
@@ -205,7 +206,7 @@ def __init__(
205206
self._hf_config = hf_config
206207
self.config_kwargs = config_kwargs
207208
tfds_config = (
208-
huggingface_utils.convert_hf_name(hf_config) if hf_config else None
209+
conversion_utils.to_tfds_name(hf_config) if hf_config else None
209210
)
210211
try:
211212
self._hf_builder = hf_datasets.load_dataset_builder(
@@ -224,7 +225,7 @@ def __init__(
224225
or '1.0.0'
225226
)
226227
self.VERSION = version_lib.Version(version) # pylint: disable=invalid-name
227-
self.name = huggingface_utils.convert_hf_name(hf_repo_id)
228+
self.name = conversion_utils.to_tfds_name(hf_repo_id)
228229
self._hf_hub_token = hf_hub_token
229230
self._hf_num_proc = hf_num_proc
230231
self._tfds_num_proc = tfds_num_proc
@@ -312,7 +313,7 @@ def _generate_splits(
312313

313314
shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
314315
for hf_split, hf_split_info in self._hf_info.splits.items():
315-
split = huggingface_utils.convert_hf_name(hf_split)
316+
split = conversion_utils.to_tfds_name(hf_split)
316317
shard_specs_by_split[split] = self._compute_shard_specs(
317318
hf_split_info, split
318319
)
@@ -455,7 +456,7 @@ def _get_text_field(self, field: str) -> str | None:
455456
def builder(
456457
name: str, config: Optional[str] = None, **builder_kwargs
457458
) -> HuggingfaceDatasetBuilder:
458-
hf_repo_id = huggingface_utils.convert_tfds_dataset_name(name)
459+
hf_repo_id = huggingface_utils.to_huggingface_name(name)
459460
return HuggingfaceDatasetBuilder(
460461
hf_repo_id=hf_repo_id, hf_config=config, **builder_kwargs
461462
)
tensorflow_datasets/core/utils/conversion_utils.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# coding=utf-8
2+
# Copyright 2024 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Utility functions to convert from other formats to TFDS conventions."""
17+
18+
from collections.abc import Mapping, Sequence
19+
import datetime
20+
from typing import Any
21+
22+
from etils import epath
23+
import numpy as np
24+
from tensorflow_datasets.core import features as feature_lib
25+
from tensorflow_datasets.core import lazy_imports_lib
26+
from tensorflow_datasets.core.utils import dtype_utils
27+
from tensorflow_datasets.core.utils import py_utils
28+
29+
# Raw PNG bytes used as the placeholder when an image value is None: a 1x1
# pixel, black image (see `_get_default_value`).
_DEFAULT_IMG = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178U\x00\x00\x00\x00IEND\xaeB`\x82'
30+
31+
32+
def to_tfds_name(name: str) -> str:
  """Normalizes an external name into one that TFDS accepts.

  Names coming from other ecosystems (e.g. Huggingface) may use characters
  that TFDS does not support. For instance, Huggingface allows a dataset to be
  called `a/b`, whereas TFDS would read `b` as the config name.

  The name is lower-cased, every `/` is replaced by `__`, and the result is
  passed through `py_utils.make_valid_name`.

  Examples:
    - `name='codeparrot/github-code'` becomes `codeparrot__github_code`.

  Args:
    name: The name to convert. Used for dataset names, config names and split
      names alike.

  Returns:
    The TFDS compatible name.
  """
  return py_utils.make_valid_name(name.lower().replace('/', '__'))
51+
52+
53+
def _get_default_value(
    feature: feature_lib.FeatureConnector,
) -> Mapping[str, Any] | Sequence[Any] | bytes | int | float | bool:
  """Builds the placeholder used when a value is missing for `feature`.

  Sources other than TFDS can be loosely typed; HuggingFace, for example,
  accepts None values. Until something like `tfds.features.Optional` exists,
  a missing value is replaced by a constant that depends on the feature:

  - `FeaturesDict`: a dict holding the default of every inner feature.
  - `Sequence` of `FeaturesDict`: a dict mapping every inner name to `[]`.
  - Any other `Sequence`: `[]`.
  - `Image`: the bytes of a 1x1 pixel, black PNG.
  - String dtype: `b''`.
  - Integer dtype: the smallest value of that dtype.
  - Floating dtype: the smallest finite value of that dtype.
  - Bool dtype: `False`.

  The dtype minimum is used for numbers (rather than 0 or -1) because 0 or -1
  can legitimately appear in the data. To detect a defaulted value, compare
  against the minimum of the feature's dtype, for example:

  ```
  np.iinfo(np.int32).min  # for int32 features
  np.finfo(np.float32).min  # for float32 features
  ```

  Args:
    feature: The TFDS feature from which we want the default value.

  Returns:
    The default value for the given feature.

  Raises:
    TypeError: If couldn't recognize feature dtype.
  """
  # Container features first: they are resolved structurally, not by dtype.
  if isinstance(feature, feature_lib.FeaturesDict):
    return {
        key: _get_default_value(sub_feature)
        for key, sub_feature in feature.items()
    }
  if isinstance(feature, feature_lib.Sequence):
    element = feature.feature
    if isinstance(element, feature_lib.FeaturesDict):
      # A sequence of dicts is represented as a dict of (empty) lists.
      return {key: [] for key in element.keys()}
    return []
  if isinstance(feature, feature_lib.Image):
    # Return an empty PNG image of 1x1 pixel, black.
    return _DEFAULT_IMG

  # Everything else is resolved from the numpy dtype of the feature.
  np_dtype = feature.np_dtype
  if dtype_utils.is_string(np_dtype):
    return b''
  if dtype_utils.is_integer(np_dtype):
    return np.iinfo(np_dtype).min
  if dtype_utils.is_floating(np_dtype):
    return np.finfo(np_dtype).min
  if dtype_utils.is_bool(np_dtype):
    return False
  raise TypeError(f'Could not recognize the dtype of {feature}')
107+
108+
109+
def to_tfds_value(value: Any, feature: feature_lib.FeatureConnector) -> Any:
110+
"""Converts a value to a TFDS compatible value.
111+
112+
Args:
113+
value: The value to be converted to follow TFDS conventions.
114+
feature: The TFDS feature for which we want the compatible value.
115+
116+
Returns:
117+
The TFDS compatible value.
118+
119+
Raises:
120+
TypeError: If couldn't recognize the given feature type.
121+
"""
122+
match value:
123+
case None:
124+
return _get_default_value(feature)
125+
case datetime.datetime():
126+
return int(value.timestamp())
127+
128+
match feature:
129+
case feature_lib.ClassLabel() | feature_lib.Scalar():
130+
return value
131+
case feature_lib.FeaturesDict():
132+
return {
133+
name: to_tfds_value(value.get(name), inner_feature)
134+
for name, inner_feature in feature.items()
135+
}
136+
case feature_lib.Sequence():
137+
match value:
138+
case dict():
139+
# Should be a dict of lists:
140+
return {
141+
name: [
142+
to_tfds_value(inner_hf_value, inner_feature)
143+
for inner_hf_value in value.get(name)
144+
]
145+
for name, inner_feature in feature.feature.items()
146+
}
147+
case list():
148+
return [
149+
to_tfds_value(inner_hf_value, feature.feature)
150+
for inner_hf_value in value
151+
]
152+
case _:
153+
return [value]
154+
case feature_lib.Audio():
155+
if array := value.get('array'):
156+
# Hugging Face uses floats, TFDS uses integers.
157+
return [int(sample * feature.sample_rate) for sample in array]
158+
elif (path := value.get('path')) and (path := epath.Path(path)).exists():
159+
return path
160+
case feature_lib.Image():
161+
value: lazy_imports_lib.lazy_imports.PIL_Image.Image
162+
# Ensure RGB format for PNG encoding.
163+
return value.convert('RGB')
164+
case feature_lib.Tensor():
165+
if isinstance(value, float):
166+
# In some cases, for example when loading jsonline files using pandas,
167+
# empty non-float values, such as strings, are converted to float nan.
168+
# We spot those occurrences as the feature.np_dtype is not float.
169+
if np.isnan(value) and not dtype_utils.is_floating(feature.np_dtype):
170+
return _get_default_value(feature)
171+
return value
172+
173+
raise TypeError(
174+
f'Conversion of value {value} to feature {feature} is not supported.'
175+
)

0 commit comments

Comments
 (0)