from __future__ import annotations

+ from collections.abc import Mapping, Sequence
+ import dataclasses
import functools
import itertools
import multiprocessing
import os
from typing import Any, Dict, Optional, Union

- from absl import logging
from etils import epath
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import download
+ from tensorflow_datasets.core import example_serializer
+ from tensorflow_datasets.core import features as feature_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import split_builder as split_builder_lib
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.utils import huggingface_utils
+ from tensorflow_datasets.core.utils import shard_utils
+ from tensorflow_datasets.core.utils import tqdm_utils
from tensorflow_datasets.core.utils import version as version_lib
from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets

- _EMPTY_SPLIT_WARNING_MSG = "%s split doesn't have any examples"
-


def _extract_supervised_keys(hf_info):
  if hf_info.supervised_keys is not None:
@@ -57,24 +60,79 @@ def _extract_supervised_keys(hf_info):
  return None


- def _remove_empty_splits(
-     splits: Dict[str, split_builder_lib.SplitGenerator]
- ) -> Dict[str, split_builder_lib.SplitGenerator]:
-   """Removes empty splits."""
-   non_empty_splits = {}
+ @dataclasses.dataclass(frozen=True)
+ class _ShardSpec:
+   """Spec to write a shard.
+
+   Attributes:
+     path: Shard path.
+     hf_split: HuggingFace split name.
+     split: TFDS split name.
+     start_index: Index of the shard start.
+     end_index: Index of the shard end.
+     num_examples: Number of examples in the shard.
+     shard_split: HuggingFace split slice for the shard, e.g. 'train[0:1000]'.
+   """
+
+   path: epath.Path
+   hf_split: str
+   split: str
+   start_index: int
+   end_index: int
+
+   @property
+   def num_examples(self) -> int:
+     return self.end_index - self.start_index
+
+   @property
+   def shard_split(self) -> str:
+     return f'{self.hf_split}[{self.start_index}:{self.end_index}]'
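
As a quick illustration (hypothetical values, not part of the change), the
shard_split property renders HuggingFace's split-slicing syntax, which lets a
shard be read back as a sub-range of its source split:

    spec = _ShardSpec(
        path=epath.Path('/data/foo-train.tfrecord-00000-of-00008'),
        hf_split='train',
        split='train',
        start_index=0,
        end_index=1000,
    )
    assert spec.num_examples == 1000
    assert spec.shard_split == 'train[0:1000]'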

-   for split, examples_iterable in splits.items():
-     examples_iterator = iter(examples_iterable)
-     # Ensure the iterator is not empty.
-     try:
-       first_example = next(examples_iterator)
-       non_empty_splits[split] = itertools.chain(
-           [first_example], examples_iterator
-       )
-     except StopIteration:
-       logging.warning(_EMPTY_SPLIT_WARNING_MSG, split)

-   return non_empty_splits
+ def _write_shard(
+     shard_spec: _ShardSpec,
+     hf_builder,
+     example_writer,
+     features: feature_lib.FeaturesDict,
+ ) -> int:
+   """Writes a shard to file.
+
+   Args:
+     shard_spec: Shard spec.
+     hf_builder: HuggingFace dataset builder.
+     example_writer: Example writer.
+     features: TFDS features dict.
+
+   Returns:
+     Shard size in bytes.
+   """
+   serialized_info = features.get_serialized_info()
+   serializer = example_serializer.ExampleSerializer(serialized_info)
+   num_bytes = 0
+
+   def get_serialized_examples_iter():
+     nonlocal num_bytes
+     for hf_value in hf_builder.as_dataset(
+         split=shard_spec.shard_split, run_post_process=False
+     ):
+       example = huggingface_utils.convert_hf_value(hf_value, features)
+       serialized_example = serializer.serialize_example(example)
+       num_bytes += len(serialized_example)
+       yield serialized_example
+
+   example_writer.write(
+       os.fspath(shard_spec.path),
+       tqdm_utils.tqdm(
+           enumerate(get_serialized_examples_iter()),
+           desc=f'Writing {shard_spec.path} examples...',
+           unit=' examples',
+           total=shard_spec.num_examples,
+           leave=False,
+           mininterval=1.0,
+       ),
+   )
+
+   return num_bytes
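
Two details worth noting in _write_shard: it lives at module level so that
functools.partial plus multiprocessing.Pool can pickle it and ship it to
worker processes, and num_bytes is accumulated as a side effect of the
generator, so its value is only final once example_writer.write has fully
consumed the iterator. A minimal sketch of that counting-generator pattern
(hypothetical names, not from the codebase):

    def write_all(consume, items):
      total = 0

      def counted():
        nonlocal total
        for item in items:
          total += len(item)
          yield item

      consume(counted())  # total is complete only after full consumption.
      return total

    assert write_all(lambda it: sum(1 for _ in it), [b'ab', b'cde']) == 5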


class HuggingfaceDatasetBuilder(
@@ -164,7 +222,7 @@ def _hf_download_and_prepare(self):
  def _hf_info(self) -> hf_datasets.DatasetInfo:
    return self._hf_builder.info

-   def _hf_features(self):
+   def _hf_features(self) -> hf_datasets.Features:
    if not self._hf_info.features:
      # We need to download and prepare the data to know its features.
      self._hf_download_and_prepare()
@@ -185,24 +243,121 @@ def _info(self) -> dataset_info_lib.DatasetInfo:
  def _split_generators(
      self, dl_manager: download.DownloadManager
  ) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
-     del dl_manager
-     self._hf_download_and_prepare()
-     ds = self._hf_builder.as_dataset(verification_mode=self._verification_mode)
-     splits = {
-         huggingface_utils.convert_hf_name(split): self._generate_examples(data)
-         for split, data in ds.items()
-     }
-     return _remove_empty_splits(splits)
+     raise NotImplementedError('This method should not be called.')

  def _generate_examples(self, data) -> split_builder_lib.SplitGenerator:
-     convert_example = functools.partial(
-         huggingface_utils.convert_hf_value, feature=self.info.features
+     raise NotImplementedError('This method should not be called.')
+
+   def _generate_splits(
+       self,
+       dl_manager: download.DownloadManager,
+       download_config: download.DownloadConfig,
+   ) -> Sequence[splits_lib.SplitInfo]:
+     """Prepares the dataset by writing shards directly."""
+     del dl_manager, download_config  # Unused.
+     self._hf_download_and_prepare()
+
+     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
+     for hf_split, hf_split_info in self._hf_info.splits.items():
+       split = huggingface_utils.convert_hf_name(hf_split)
+       shard_specs_by_split[split] = self._compute_shard_specs(
+           hf_split_info, split
+       )
+
+     shard_sizes_by_split = self._write_shards(shard_specs_by_split)
+
+     return [
+         splits_lib.SplitInfo(
+             name=split,
+             shard_lengths=[
+                 shard_spec.num_examples for shard_spec in shard_specs
+             ],
+             num_bytes=sum(shard_sizes_by_split[split]),
+             filename_template=self._get_filename_template(split),
+         )
+         for split, shard_specs in shard_specs_by_split.items()
+     ]
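
In short, the new prepare path is: compute where every example range should
land on disk, stream those ranges out of the HF builder into record files,
then report per-split metadata. A comment-level sketch (names from the code
above):

    # For each HF split:
    #   _compute_shard_specs()  -> [_ShardSpec('train[0:N]'), ...]
    #   _write_shards()         -> bytes written per shard (parallelizable)
    #   splits_lib.SplitInfo()  -> shard_lengths + num_bytes for dataset_info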
+
+   def _compute_shard_specs(
+       self, hf_split_info: hf_datasets.SplitInfo, split: str
+   ) -> Sequence[_ShardSpec]:
+     """Returns specs for evenly spread shards.
+
+     Args:
+       hf_split_info: HuggingFace split info.
+       split: TFDS split name.
+     """
+     # HF split size is good enough for estimating the number of shards.
+     num_shards = shard_utils.ShardConfig.calculate_number_shards(
+         total_size=hf_split_info.num_bytes,
+         num_examples=hf_split_info.num_examples,
+         uses_precise_sharding=False,
+     )
+     filename_template = self._get_filename_template(split)
+     shard_boundaries = shard_utils.get_shard_boundaries(
+         num_examples=hf_split_info.num_examples, number_of_shards=num_shards
+     )
+
+     prev_shard_boundary = 0
+     shard_specs: list[_ShardSpec] = []
+
+     for shard_index, shard_boundary in enumerate(shard_boundaries):
+       shard_specs.append(
+           _ShardSpec(
+               path=filename_template.sharded_filepath(
+                   shard_index=shard_index, num_shards=len(shard_boundaries)
+               ),
+               hf_split=hf_split_info.name,
+               split=split,
+               start_index=prev_shard_boundary,
+               end_index=shard_boundary,
+           )
+       )
+       prev_shard_boundary = shard_boundary
+
+     return shard_specs
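
To make the boundaries concrete, here is a hedged sketch of an even split
(the actual shard_utils.get_shard_boundaries may round differently):

    def even_boundaries(num_examples: int, num_shards: int) -> list[int]:
      # Cumulative end indices; the last boundary always equals num_examples.
      return [
          round(num_examples * (i + 1) / num_shards) for i in range(num_shards)
      ]

    assert even_boundaries(100, 3) == [33, 67, 100]
    # The loop above would then emit specs slicing train[0:33], train[33:67]
    # and train[67:100].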
+
+   def _write_shards(
+       self,
+       shard_specs_by_split: Mapping[str, Sequence[_ShardSpec]],
+   ) -> Mapping[str, Sequence[int]]:
+     """Writes shards to files.
+
+     Args:
+       shard_specs_by_split: Shard specs by split name.
+
+     Returns:
+       Shard sizes in bytes, keyed by split name.
+     """
+     shard_specs = list(itertools.chain(*shard_specs_by_split.values()))
+     shard_specs = tqdm_utils.tqdm(
+         shard_specs,
+         desc='Writing shards...',
+         unit=' shards',
+         total=len(shard_specs),
+         leave=False,
+     )
+     write_shard = functools.partial(
+         _write_shard,
+         hf_builder=self._hf_builder,
+         example_writer=self._example_writer(),
+         features=self.info.features,
    )
+
    if self._tfds_num_proc is None:
-       yield from enumerate(map(convert_example, data))
+       shard_sizes = list(map(write_shard, shard_specs))
    else:
      with multiprocessing.Pool(processes=self._tfds_num_proc) as pool:
-         yield from enumerate(pool.imap(convert_example, data))
+         shard_sizes = pool.map(write_shard, shard_specs)
+
+     shard_idx = 0
+     shard_sizes_by_split: dict[str, Sequence[int]] = {}
+     for split, shard_specs in shard_specs_by_split.items():
+       shard_sizes_by_split[split] = shard_sizes[
+           shard_idx : shard_idx + len(shard_specs)
+       ]
+       shard_idx += len(shard_specs)
+     return shard_sizes_by_split
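
pool.map preserves input order, which is what makes the re-chunking at the
end of _write_shards correct: the flat shard_sizes list lines up positionally
with the flattened shard specs. A minimal model of that re-chunking
(hypothetical values):

    def chunk_by(sizes, counts):
      out, idx = {}, 0
      for name, n in counts.items():
        out[name] = sizes[idx : idx + n]
        idx += n
      return out

    assert chunk_by([10, 20, 30, 40, 50], {'train': 3, 'test': 2}) == {
        'train': [10, 20, 30],
        'test': [40, 50],
    }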


def builder(