Skip to content

Commit 680df0b

Browse files
marcenacp authored and The TensorFlow Datasets Authors committed
Add the possibility to silence errors for Hugging Face.
PiperOrigin-RevId: 626937096
1 parent df39eff commit 680df0b

File tree

2 files changed

+75
-24
lines changed

2 files changed

+75
-24
lines changed

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import os
3535
from typing import Any, Dict, Optional, Union
3636

37+
from absl import logging
3738
from etils import epath
3839
from tensorflow_datasets.core import dataset_builder
3940
from tensorflow_datasets.core import dataset_info as dataset_info_lib
@@ -89,32 +90,64 @@ def shard_split(self) -> str:
8990
return f'{self.hf_split}[{self.start_index}:{self.end_index}]'
9091

9192

93+
@dataclasses.dataclass(frozen=True)
94+
class _ShardInfo:
95+
"""Information about a shard after it is generated.
96+
97+
_ShardSpec is the input to the shard generation. This is the output.
98+
99+
Attributes:
100+
num_bytes: Actual number of bytes in the shard.
101+
num_examples: Actual number of examples in the shard.
102+
num_exceptions: Number of exceptions during retrieval.
103+
"""
104+
105+
num_bytes: int
106+
num_examples: int
107+
num_exceptions: int
108+
109+
92110
def _write_shard(
93111
shard_spec: _ShardSpec,
94112
hf_builder,
95113
example_writer,
96114
features: feature_lib.FeaturesDict,
97-
) -> int:
115+
ignore_hf_errors: bool,
116+
) -> _ShardInfo:
98117
"""Writes shard to the file.
99118
100119
Args:
101120
shard_spec: Shard spec.
102121
hf_builder: HuggingFace dataset builder.
103122
example_writer: Example writer.
104123
features: TFDS features dict.
124+
ignore_hf_errors: Whether to silence and log Hugging Face errors during
125+
retrieval.
105126
106127
Returns:
107-
Shard size in bytes.
128+
A _ShardInfo containing the actual shard information.
108129
"""
109130
serialized_info = features.get_serialized_info()
110131
serializer = example_serializer.ExampleSerializer(serialized_info)
111132
num_bytes = 0
133+
num_exceptions = 0
112134

113135
def get_serialized_examples_iter():
114136
nonlocal num_bytes
115-
for hf_value in hf_builder.as_dataset(
137+
nonlocal num_exceptions
138+
dataset = hf_builder.as_dataset(
116139
split=shard_spec.shard_split, run_post_process=False
117-
):
140+
)
141+
for i in range(shard_spec.num_examples):
142+
try:
143+
hf_value = dataset[i]
144+
except Exception: # pylint: disable=broad-exception-caught
145+
num_exceptions += 1
146+
if ignore_hf_errors:
147+
logging.exception('Ignoring Hugging Face error')
148+
continue
149+
else:
150+
raise
118151
example = huggingface_utils.convert_hf_value(hf_value, features)
119152
encoded_example = features.encode_example(example)
120153
serialized_example = serializer.serialize_example(encoded_example)
@@ -133,7 +166,11 @@ def get_serialized_examples_iter():
133166
),
134167
)
135168

136-
return num_bytes
169+
return _ShardInfo(
170+
num_bytes=num_bytes,
171+
num_examples=shard_spec.num_examples - num_exceptions,
172+
num_exceptions=num_exceptions,
173+
)
137174

138175

139176
class HuggingfaceDatasetBuilder(
@@ -160,6 +197,7 @@ def __init__(
160197
hf_hub_token: Optional[str] = None,
161198
hf_num_proc: Optional[int] = None,
162199
tfds_num_proc: Optional[int] = None,
200+
ignore_hf_errors: bool = False,
163201
**config_kwargs,
164202
):
165203
self._hf_repo_id = hf_repo_id
@@ -199,6 +237,7 @@ def __init__(
199237
if self._hf_config:
200238
self._builder_config = self._converted_builder_config
201239
self.generation_errors = []
240+
self._ignore_hf_errors = ignore_hf_errors
202241

203242
@property
204243
def builder_config(self) -> Optional[Any]:
@@ -262,19 +301,20 @@ def _generate_splits(
262301
hf_split_info, split
263302
)
264303

265-
shard_sizes_by_split = self._write_shards(shard_specs_by_split)
266-
267-
return [
268-
splits_lib.SplitInfo(
269-
name=split,
270-
shard_lengths=[
271-
shard_spec.num_examples for shard_spec in shard_specs
272-
],
273-
num_bytes=sum(shard_sizes_by_split[split]),
274-
filename_template=self._get_filename_template(split),
275-
)
276-
for split, shard_specs in shard_specs_by_split.items()
277-
]
304+
shard_infos_by_split = self._write_shards(shard_specs_by_split)
305+
split_infos: list[splits_lib.SplitInfo] = []
306+
for split, shard_infos in shard_infos_by_split.items():
307+
shard_lengths = [shard_info.num_examples for shard_info in shard_infos]
308+
num_bytes = sum(shard_info.num_bytes for shard_info in shard_infos)
309+
split_infos.append(
310+
splits_lib.SplitInfo(
311+
name=split,
312+
shard_lengths=shard_lengths,
313+
num_bytes=num_bytes,
314+
filename_template=self._get_filename_template(split),
315+
)
316+
)
317+
return split_infos
278318

279319
def _compute_shard_specs(
280320
self, hf_split_info: hf_datasets.SplitInfo, split: str
@@ -318,7 +358,7 @@ def _compute_shard_specs(
318358
def _write_shards(
319359
self,
320360
shard_specs_by_split: Mapping[str, Sequence[_ShardSpec]],
321-
) -> Mapping[str, Sequence[int]]:
361+
) -> Mapping[str, Sequence[_ShardInfo]]:
322362
"""Writes shards to files.
323363
324364
Args:
@@ -340,22 +380,32 @@ def _write_shards(
340380
hf_builder=self._hf_builder,
341381
example_writer=self._example_writer(),
342382
features=self.info.features,
383+
ignore_hf_errors=self._ignore_hf_errors,
343384
)
344385

345386
if self._tfds_num_proc is None:
346-
shard_sizes = list(map(write_shard, shard_specs))
387+
shard_infos = list(map(write_shard, shard_specs))
347388
else:
348389
with multiprocessing.Pool(processes=self._tfds_num_proc) as pool:
349-
shard_sizes = pool.map(write_shard, shard_specs)
390+
shard_infos = pool.map(write_shard, shard_specs)
350391

351392
shard_idx = 0
352-
shard_sizes_by_split: dict[str, Sequence[int]] = {}
393+
shard_infos_by_split: dict[str, Sequence[_ShardInfo]] = {}
353394
for split, shard_specs in shard_specs_by_split.items():
354-
shard_sizes_by_split[split] = shard_sizes[
395+
shard_infos_by_split[split] = shard_infos[
355396
shard_idx : shard_idx + len(shard_specs)
356397
]
357398
shard_idx += len(shard_specs)
358-
return shard_sizes_by_split
399+
expected_num_examples = sum(spec.num_examples for spec in shard_specs)
400+
if self._ignore_hf_errors and expected_num_examples > 0:
401+
num_exceptions = sum(info.num_exceptions for info in shard_infos)
402+
percentage_exceptions = num_exceptions / expected_num_examples * 100
403+
logging.info(
404+
'Got %d exceptions (%.2f%%) during Hugging Face generation',
405+
num_exceptions,
406+
percentage_exceptions,
407+
)
408+
return shard_infos_by_split
359409

360410

361411
def builder(

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def test_download_and_prepare(builder):
106106
):
107107
for feature in ['number', 'text', 'image']:
108108
assert np.array_equal(element[feature], expected[feature])
109+
assert len(ds['train_clean']) == 2
109110

110111

111112
def test_all_parameters_are_passed_down_to_hf(builder):

0 commit comments

Comments (0)