 import os
 from typing import Any, Dict, Optional, Union

+from absl import logging
 from etils import epath
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
@@ -89,32 +90,64 @@ def shard_split(self) -> str:
     return f'{self.hf_split}[{self.start_index}:{self.end_index}]'


+@dataclasses.dataclass(frozen=True)
+class _ShardInfo:
+  """Information about a shard after it is generated.
+
+  _ShardSpec is the input to the shard generation. This is the output.
+
+  Attributes:
+    num_bytes: Actual number of bytes in the shard.
+    num_examples: Actual number of examples in the shard.
+    num_exceptions: Number of exceptions during retrieval.
+  """
+
+  num_bytes: int
+  num_examples: int
+  num_exceptions: int
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
     hf_builder,
     example_writer,
     features: feature_lib.FeaturesDict,
-) -> int:
+    ignore_hf_errors: bool,
+) -> _ShardInfo:
   """Writes shard to the file.

   Args:
     shard_spec: Shard spec.
     hf_builder: HuggingFace dataset builder.
     example_writer: Example writer.
     features: TFDS features dict.
+    ignore_hf_errors: Whether to silence and log Hugging Face errors during
+      retrieval.

   Returns:
-    Shard size in bytes.
+    A _ShardInfo containing the actual shard information.
   """
   serialized_info = features.get_serialized_info()
   serializer = example_serializer.ExampleSerializer(serialized_info)
   num_bytes = 0
+  num_exceptions = 0

   def get_serialized_examples_iter():
     nonlocal num_bytes
-    for hf_value in hf_builder.as_dataset(
+    nonlocal num_exceptions
+    dataset = hf_builder.as_dataset(
         split=shard_spec.shard_split, run_post_process=False
-    ):
+    )
+    for i in range(shard_spec.num_examples):
+      try:
+        hf_value = dataset[i]
+      except Exception:  # pylint: disable=broad-exception-caught
+        num_exceptions += 1
+        if ignore_hf_errors:
+          logging.exception('Ignoring Hugging Face error')
+          continue
+        else:
+          raise
       example = huggingface_utils.convert_hf_value(hf_value, features)
       encoded_example = features.encode_example(example)
       serialized_example = serializer.serialize_example(encoded_example)
@@ -133,7 +166,11 @@ def get_serialized_examples_iter():
       ),
   )

-  return num_bytes
+  return _ShardInfo(
+      num_bytes=num_bytes,
+      num_examples=shard_spec.num_examples - num_exceptions,
+      num_exceptions=num_exceptions,
+  )


 class HuggingfaceDatasetBuilder(
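Review note: the switch above from iterating `hf_builder.as_dataset(...)` to indexed access (`dataset[i]`) is what makes per-example error handling possible; once an exception escapes a plain `for` loop, the iterator cannot be resumed, whereas indexed access lets the loop skip the failing example and keep counting. A minimal standalone sketch of the pattern, with illustrative names that are not part of this PR:

```python
def robust_examples(dataset, num_examples, ignore_errors=True):
  """Yields dataset[i], counting and optionally skipping failures."""
  skipped = 0
  for i in range(num_examples):
    try:
      value = dataset[i]  # indexed access isolates per-example failures
    except Exception:  # broad catch, mirroring the diff's behavior
      skipped += 1
      if not ignore_errors:
        raise
      continue
    yield value
  # Callers can subtract `skipped` from num_examples, as the PR does
  # when it computes _ShardInfo.num_examples.
```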
@@ -160,6 +197,7 @@ def __init__(
       hf_hub_token: Optional[str] = None,
       hf_num_proc: Optional[int] = None,
       tfds_num_proc: Optional[int] = None,
+      ignore_hf_errors: bool = False,
       **config_kwargs,
   ):
     self._hf_repo_id = hf_repo_id
@@ -199,6 +237,7 @@ def __init__(
     if self._hf_config:
       self._builder_config = self._converted_builder_config
     self.generation_errors = []
+    self._ignore_hf_errors = ignore_hf_errors

   @property
   def builder_config(self) -> Optional[Any]:
@@ -262,19 +301,20 @@ def _generate_splits(
         hf_split_info, split
     )

-    shard_sizes_by_split = self._write_shards(shard_specs_by_split)
-
-    return [
-        splits_lib.SplitInfo(
-            name=split,
-            shard_lengths=[
-                shard_spec.num_examples for shard_spec in shard_specs
-            ],
-            num_bytes=sum(shard_sizes_by_split[split]),
-            filename_template=self._get_filename_template(split),
-        )
-        for split, shard_specs in shard_specs_by_split.items()
-    ]
+    shard_infos_by_split = self._write_shards(shard_specs_by_split)
+    split_infos: list[splits_lib.SplitInfo] = []
+    for split, shard_infos in shard_infos_by_split.items():
+      shard_lengths = [shard_info.num_examples for shard_info in shard_infos]
+      num_bytes = sum(shard_info.num_bytes for shard_info in shard_infos)
+      split_infos.append(
+          splits_lib.SplitInfo(
+              name=split,
+              shard_lengths=shard_lengths,
+              num_bytes=num_bytes,
+              filename_template=self._get_filename_template(split),
+          )
+      )
+    return split_infos

   def _compute_shard_specs(
       self, hf_split_info: hf_datasets.SplitInfo, split: str
@@ -318,7 +358,7 @@ def _compute_shard_specs(
   def _write_shards(
       self,
       shard_specs_by_split: Mapping[str, Sequence[_ShardSpec]],
-  ) -> Mapping[str, Sequence[int]]:
+  ) -> Mapping[str, Sequence[_ShardInfo]]:
     """Writes shards to files.

     Args:
@@ -340,22 +380,32 @@ def _write_shards(
         hf_builder=self._hf_builder,
         example_writer=self._example_writer(),
         features=self.info.features,
+        ignore_hf_errors=self._ignore_hf_errors,
     )

     if self._tfds_num_proc is None:
-      shard_sizes = list(map(write_shard, shard_specs))
+      shard_infos = list(map(write_shard, shard_specs))
     else:
       with multiprocessing.Pool(processes=self._tfds_num_proc) as pool:
-        shard_sizes = pool.map(write_shard, shard_specs)
+        shard_infos = pool.map(write_shard, shard_specs)

     shard_idx = 0
-    shard_sizes_by_split: dict[str, Sequence[int]] = {}
+    shard_infos_by_split: dict[str, Sequence[_ShardInfo]] = {}
     for split, shard_specs in shard_specs_by_split.items():
-      shard_sizes_by_split[split] = shard_sizes[
+      shard_infos_by_split[split] = shard_infos[
           shard_idx : shard_idx + len(shard_specs)
       ]
       shard_idx += len(shard_specs)
-    return shard_sizes_by_split
+    expected_num_examples = sum(spec.num_examples for spec in shard_specs)
+    if self._ignore_hf_errors and expected_num_examples > 0:
+      num_exceptions = sum(info.num_exceptions for info in shard_infos)
+      percentage_exceptions = num_exceptions / expected_num_examples * 100
+      logging.info(
+          'Got %d exceptions (%.2f%%) during Hugging Face generation',
+          num_exceptions,
+          percentage_exceptions,
+      )
+    return shard_infos_by_split

 def builder(
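For reviewers trying the flag end to end, a hedged usage sketch (the import path and the constructor arguments other than `ignore_hf_errors` are assumptions for illustration, not shown in this diff):

```python
from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder

# Hypothetical invocation: only ignore_hf_errors comes from this PR;
# the other arguments are assumed/illustrative.
builder = huggingface_dataset_builder.HuggingfaceDatasetBuilder(
    file_format='tfrecord',         # assumed required by the constructor
    hf_repo_id='username/dataset',  # placeholder repo id
    ignore_hf_errors=True,          # skip and log per-example HF failures
)
builder.download_and_prepare()
```

With `ignore_hf_errors=False` (the default), the first retrieval error still propagates, so existing behavior is unchanged.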