Commit f27c205

fineguy authored and The TensorFlow Datasets Authors committed

Optimize HuggingFace dataset preparation.

PiperOrigin-RevId: 616102920

1 parent: 2f6c156

File tree

3 files changed: +217, -64 lines


tensorflow_datasets/core/dataset_builder.py

Lines changed: 30 additions & 18 deletions
@@ -1560,12 +1560,25 @@ def _example_writer(self) -> writer_lib.ExampleWriter:
     """
     return writer_lib.ExampleWriter(file_format=self.info.file_format)
 
-  def _download_and_prepare(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
+  def _get_filename_template(
+      self, split_name: str
+  ) -> naming.ShardedFileTemplate:
+    """Returns a filename template for the given split."""
+    return naming.ShardedFileTemplate(
+        split=split_name,
+        dataset_name=self.name,
+        data_dir=self.data_path,
+        filetype_suffix=file_adapters.ADAPTER_FOR_FORMAT[
+            self.info.file_format
+        ].FILE_SUFFIX,
+    )
+
+  def _generate_splits(
       self,
       dl_manager: download.DownloadManager,
       download_config: download.DownloadConfig,
-  ) -> None:
-    """Generate all splits and returns the computed split infos."""
+  ) -> Sequence[splits_lib.SplitInfo]:
+    """Generates all splits and returns the computed split infos."""
     split_builder = split_builder_lib.SplitBuilder(
         split_dict=self.info.splits,
         features=self.info.features,
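The new `_get_filename_template` helper centralizes shard-path construction that was previously inlined in `_download_and_prepare`. As a rough, self-contained sketch of what such a template expands to (this mirrors TFDS's usual `<name>-<split>.<suffix>-NNNNN-of-NNNNN` shard naming, but it is not the real `naming.ShardedFileTemplate`):

```python
# Illustrative stand-in for naming.ShardedFileTemplate -- not the real class.
import dataclasses


@dataclasses.dataclass(frozen=True)
class ShardedFileTemplateSketch:
  data_dir: str
  dataset_name: str
  split: str
  filetype_suffix: str  # e.g. 'tfrecord' or 'array_record'

  def sharded_filepath(self, shard_index: int, num_shards: int) -> str:
    # TFDS-style shard name: <name>-<split>.<suffix>-00000-of-00008
    return (
        f'{self.data_dir}/{self.dataset_name}-{self.split}'
        f'.{self.filetype_suffix}-{shard_index:05d}-of-{num_shards:05d}'
    )


template = ShardedFileTemplateSketch(
    data_dir='/data/my_dataset/1.0.0',
    dataset_name='my_dataset',
    split='train',
    filetype_suffix='tfrecord',
)
print(template.sharded_filepath(shard_index=0, num_shards=8))
# -> /data/my_dataset/1.0.0/my_dataset-train.tfrecord-00000-of-00008
```

Pulling this into one method lets the HuggingFace builder below produce identical shard paths without going through the generic split-generation pipeline.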
@@ -1610,28 +1623,15 @@ def _download_and_prepare(  # pytype: disable=signature-mismatch  # overriding-p
     # Ensure `all` isn't used as key.
     _check_split_names(split_generators.keys())
 
-    # Writer fail if the number of example yield is `0`, so we return here.
-    if download_config.max_examples_per_split == 0:
-      return
-
     # Start generating data for all splits
-    path_suffix = file_adapters.ADAPTER_FOR_FORMAT[
-        self.info.file_format
-    ].FILE_SUFFIX
-
     split_info_futures = []
     for split_name, generator in utils.tqdm(
         split_generators.items(),
         desc="Generating splits...",
         unit=" splits",
         leave=False,
     ):
-      filename_template = naming.ShardedFileTemplate(
-          split=split_name,
-          dataset_name=self.name,
-          data_dir=self.data_path,
-          filetype_suffix=path_suffix,
-      )
+      filename_template = self._get_filename_template(split_name=split_name)
       future = split_builder.submit_split_generation(
           split_name=split_name,
           generator=generator,
@@ -1645,7 +1645,19 @@ def _download_and_prepare(  # pytype: disable=signature-mismatch  # overriding-p
       self._process_pipeline_result(pipeline_result=maybe_pipeline_proxy.result)
 
     # Finalize the splits (after apache beam completed, if it was used)
-    split_infos = [future.result() for future in split_info_futures]
+    return [future.result() for future in split_info_futures]
+
+  def _download_and_prepare(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
+      self,
+      dl_manager: download.DownloadManager,
+      download_config: download.DownloadConfig,
+  ) -> None:
+    """Generates all splits and sets the computed split infos."""
+    # Writer fails if the number of examples yielded is `0`, so we return here.
+    if download_config.max_examples_per_split == 0:
+      return
+
+    split_infos = self._generate_splits(dl_manager, download_config)
 
     # Update the info object with the splits.
     split_dict = splits_lib.SplitDict(split_infos)
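Net effect of this file's changes: `_download_and_prepare` becomes a thin wrapper, and all per-split generation now lives in an overridable `_generate_splits` hook. A minimal sketch of the resulting template-method shape (simplified names and types, not the actual TFDS classes):

```python
# Simplified sketch of the refactored control flow; not the real TFDS API.
from typing import Sequence


class SplitInfoSketch:
  def __init__(self, name: str, num_examples: int):
    self.name, self.num_examples = name, num_examples


class BaseBuilderSketch:
  def _generate_splits(self) -> Sequence[SplitInfoSketch]:
    # Default path: run the generic example-generation pipeline per split.
    return [SplitInfoSketch('train', 100), SplitInfoSketch('test', 10)]

  def download_and_prepare(self) -> None:
    # The base class keeps the bookkeeping (early exits, info updates);
    # subclasses only swap out how splits are produced.
    split_infos = self._generate_splits()
    self.splits = {info.name: info for info in split_infos}


class DirectWriteBuilderSketch(BaseBuilderSketch):
  def _generate_splits(self) -> Sequence[SplitInfoSketch]:
    # Override point, used below by HuggingfaceDatasetBuilder to write
    # shards directly instead of going through the pipeline.
    return [SplitInfoSketch('train', 100)]
```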

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 187 additions & 32 deletions
@@ -26,27 +26,30 @@
 
 from __future__ import annotations
 
+from collections.abc import Mapping, Sequence
+import dataclasses
 import functools
 import itertools
 import multiprocessing
 import os
 from typing import Any, Dict, Optional, Union
 
-from absl import logging
 from etils import epath
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import download
+from tensorflow_datasets.core import example_serializer
+from tensorflow_datasets.core import features as feature_lib
 from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core import split_builder as split_builder_lib
 from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core.utils import huggingface_utils
+from tensorflow_datasets.core.utils import shard_utils
+from tensorflow_datasets.core.utils import tqdm_utils
 from tensorflow_datasets.core.utils import version as version_lib
 from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets
 
-_EMPTY_SPLIT_WARNING_MSG = "%s split doesn't have any examples"
-
 
 def _extract_supervised_keys(hf_info):
   if hf_info.supervised_keys is not None:
@@ -57,24 +60,79 @@ def _extract_supervised_keys(hf_info):
   return None
 
 
-def _remove_empty_splits(
-    splits: Dict[str, split_builder_lib.SplitGenerator]
-) -> Dict[str, split_builder_lib.SplitGenerator]:
-  """Removes empty splits."""
-  non_empty_splits = {}
+@dataclasses.dataclass(frozen=True)
+class _ShardSpec:
+  """Spec to write a shard.
+
+  Attributes:
+    path: Shard path.
+    hf_split: HuggingFace split name.
+    split: TFDS split name.
+    start_index: Index of the shard start.
+    end_index: Index of the shard end.
+    num_examples: Number of examples in the shard.
+    shard_split: HuggingFace split for the shard.
+  """
+
+  path: epath.Path
+  hf_split: str
+  split: str
+  start_index: int
+  end_index: int
+
+  @property
+  def num_examples(self) -> int:
+    return self.end_index - self.start_index
+
+  @property
+  def shard_split(self) -> str:
+    return f'{self.hf_split}[{self.start_index}:{self.end_index}]'
 
-  for split, examples_iterable in splits.items():
-    examples_iterator = iter(examples_iterable)
-    # ensure the iterator is not empty
-    try:
-      first_example = next(examples_iterator)
-      non_empty_splits[split] = itertools.chain(
-          [first_example], examples_iterator
-      )
-    except StopIteration:
-      logging.warning(_EMPTY_SPLIT_WARNING_MSG, split)
 
-  return non_empty_splits
+def _write_shard(
+    shard_spec: _ShardSpec,
+    hf_builder,
+    example_writer,
+    features: feature_lib.FeaturesDict,
+) -> int:
+  """Writes shard to the file.
+
+  Args:
+    shard_spec: Shard spec.
+    hf_builder: HuggingFace dataset builder.
+    example_writer: Example writer.
+    features: TFDS features dict.
+
+  Returns:
+    Shard size in bytes.
+  """
+  serialized_info = features.get_serialized_info()
+  serializer = example_serializer.ExampleSerializer(serialized_info)
+  num_bytes = 0
+
+  def get_serialized_examples_iter():
+    nonlocal num_bytes
+    for hf_value in hf_builder.as_dataset(
+        split=shard_spec.shard_split, run_post_process=False
+    ):
+      example = huggingface_utils.convert_hf_value(hf_value, features)
+      serialized_example = serializer.serialize_example(example)
+      num_bytes += len(serialized_example)
+      yield serialized_example
+
+  example_writer.write(
+      os.fspath(shard_spec.path),
+      tqdm_utils.tqdm(
+          enumerate(get_serialized_examples_iter()),
+          desc=f'Writing {shard_spec.path} examples...',
+          unit=' examples',
+          total=shard_spec.num_examples,
+          leave=False,
+          mininterval=1.0,
+      ),
+  )
+
+  return num_bytes
 
 
 class HuggingfaceDatasetBuilder(
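Two details in this hunk are worth unpacking. First, `_ShardSpec.shard_split` leans on HuggingFace's split-slicing syntax, so `hf_builder.as_dataset(split='train[0:1000]')` materializes only that shard's rows. Second, `_write_shard` tallies the shard's byte size through a `nonlocal` counter while the writer consumes the stream, avoiding a second pass over the data. A self-contained sketch of both ideas (the serializer below is a stand-in, not TFDS's `ExampleSerializer`):

```python
import dataclasses
from typing import Callable, Iterable, Iterator


@dataclasses.dataclass(frozen=True)
class ShardSpecSketch:
  hf_split: str
  start_index: int
  end_index: int

  @property
  def shard_split(self) -> str:
    # HuggingFace understands sliced split names such as 'train[0:1000]'.
    return f'{self.hf_split}[{self.start_index}:{self.end_index}]'


spec = ShardSpecSketch(hf_split='train', start_index=0, end_index=1000)
assert spec.shard_split == 'train[0:1000]'


def write_shard_sketch(
    rows: Iterable[str], serialize: Callable[[str], bytes]
) -> int:
  """Streams serialized rows like _write_shard and returns total bytes."""
  num_bytes = 0

  def serialized_iter() -> Iterator[bytes]:
    nonlocal num_bytes  # Updated while the consumer drains the generator.
    for row in rows:
      data = serialize(row)
      num_bytes += len(data)
      yield data

  for _ in serialized_iter():
    pass  # A real example_writer would persist each record here.
  return num_bytes


assert write_shard_sketch(['ab', 'cde'], str.encode) == 5
```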
@@ -164,7 +222,7 @@ def _hf_download_and_prepare(self):
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     return self._hf_builder.info
 
-  def _hf_features(self):
+  def _hf_features(self) -> hf_datasets.Features:
     if not self._hf_info.features:
       # We need to download and prepare the data to know its features.
       self._hf_download_and_prepare()
@@ -185,24 +243,121 @@ def _info(self) -> dataset_info_lib.DatasetInfo:
   def _split_generators(
       self, dl_manager: download.DownloadManager
   ) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
-    del dl_manager
-    self._hf_download_and_prepare()
-    ds = self._hf_builder.as_dataset(verification_mode=self._verification_mode)
-    splits = {
-        huggingface_utils.convert_hf_name(split): self._generate_examples(data)
-        for split, data in ds.items()
-    }
-    return _remove_empty_splits(splits)
+    raise NotImplementedError('This method should not be called.')
 
   def _generate_examples(self, data) -> split_builder_lib.SplitGenerator:
-    convert_example = functools.partial(
-        huggingface_utils.convert_hf_value, feature=self.info.features
+    raise NotImplementedError('This method should not be called.')
+
+  def _generate_splits(
+      self,
+      dl_manager: download.DownloadManager,
+      download_config: download.DownloadConfig,
+  ) -> Sequence[splits_lib.SplitInfo]:
+    """Prepares the dataset by writing to shards directly."""
+    del dl_manager, download_config  # Unused.
+    self._hf_download_and_prepare()
+
+    shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
+    for hf_split, hf_split_info in self._hf_info.splits.items():
+      split = huggingface_utils.convert_hf_name(hf_split)
+      shard_specs_by_split[split] = self._compute_shard_specs(
+          hf_split_info, split
+      )
+
+    shard_sizes_by_split = self._write_shards(shard_specs_by_split)
+
+    return [
+        splits_lib.SplitInfo(
+            name=split,
+            shard_lengths=[
+                shard_spec.num_examples for shard_spec in shard_specs
+            ],
+            num_bytes=sum(shard_sizes_by_split[split]),
+            filename_template=self._get_filename_template(split),
+        )
+        for split, shard_specs in shard_specs_by_split.items()
+    ]
+
+  def _compute_shard_specs(
+      self, hf_split_info: hf_datasets.SplitInfo, split: str
+  ) -> Sequence[_ShardSpec]:
+    """Returns specs for evenly spread shards.
+
+    Args:
+      hf_split_info: HuggingFace split info.
+      split: TFDS split name.
+    """
+    # HF split size is good enough for estimating the number of shards.
+    num_shards = shard_utils.ShardConfig.calculate_number_shards(
+        total_size=hf_split_info.num_bytes,
+        num_examples=hf_split_info.num_examples,
+        uses_precise_sharding=False,
+    )
+    filename_template = self._get_filename_template(split)
+    shard_boundaries = shard_utils.get_shard_boundaries(
+        num_examples=hf_split_info.num_examples, number_of_shards=num_shards
+    )
+
+    prev_shard_boundary = 0
+    shard_specs: list[_ShardSpec] = []
+
+    for shard_index, shard_boundary in enumerate(shard_boundaries):
+      shard_specs.append(
+          _ShardSpec(
+              path=filename_template.sharded_filepath(
+                  shard_index=shard_index, num_shards=len(shard_boundaries)
+              ),
+              hf_split=hf_split_info.name,
+              split=split,
+              start_index=prev_shard_boundary,
+              end_index=shard_boundary,
+          )
+      )
+      prev_shard_boundary = shard_boundary
+
+    return shard_specs
+
+  def _write_shards(
+      self,
+      shard_specs_by_split: Mapping[str, Sequence[_ShardSpec]],
+  ) -> Mapping[str, Sequence[int]]:
+    """Writes shards to files.
+
+    Args:
+      shard_specs_by_split: Shard specs by split name.
+
+    Returns:
+      Shard sizes in bytes.
+    """
+    shard_specs = list(itertools.chain(*shard_specs_by_split.values()))
+    shard_specs = tqdm_utils.tqdm(
+        shard_specs,
+        desc='Writing shards...',
+        unit=' shards',
+        total=len(shard_specs),
+        leave=False,
+    )
+    write_shard = functools.partial(
+        _write_shard,
+        hf_builder=self._hf_builder,
+        example_writer=self._example_writer(),
+        features=self.info.features,
     )
 
     if self._tfds_num_proc is None:
-      yield from enumerate(map(convert_example, data))
+      shard_sizes = list(map(write_shard, shard_specs))
     else:
       with multiprocessing.Pool(processes=self._tfds_num_proc) as pool:
-        yield from enumerate(pool.imap(convert_example, data))
+        shard_sizes = pool.map(write_shard, shard_specs)
+
+    shard_idx = 0
+    shard_sizes_by_split: dict[str, Sequence[int]] = {}
+    for split, shard_specs in shard_specs_by_split.items():
+      shard_sizes_by_split[split] = shard_sizes[
+          shard_idx : shard_idx + len(shard_specs)
+      ]
+      shard_idx += len(shard_specs)
+    return shard_sizes_by_split
 
 
 def builder(
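How the new pieces fit together: `_compute_shard_specs` turns cumulative shard boundaries into `[start:end)` ranges, `_write_shards` flattens every split's specs so a single pool can process them all, and the flat list of returned sizes is re-sliced per split afterwards (pool results preserve input order). A runnable miniature of that flow, with a stand-in boundary function (assumed to behave like `shard_utils.get_shard_boundaries`) and a trivial pretend-writer:

```python
import functools
import itertools
import multiprocessing


def get_boundaries_sketch(num_examples: int, num_shards: int) -> list[int]:
  # Stand-in assumed to mirror shard_utils.get_shard_boundaries: cumulative
  # end indices spreading examples as evenly as possible across shards.
  return [round(num_examples * (i + 1) / num_shards) for i in range(num_shards)]


def write_shard_sketch(bounds: tuple[int, int], bytes_per_example: int) -> int:
  start, end = bounds
  return (end - start) * bytes_per_example  # Pretend-write; return shard size.


if __name__ == '__main__':
  boundaries = get_boundaries_sketch(num_examples=10, num_shards=3)
  assert boundaries == [3, 7, 10]  # Shards cover [0:3], [3:7], [7:10].

  shard_bounds_by_split = {
      'train': list(zip([0] + boundaries[:-1], boundaries)),
      'test': [(0, 4)],
  }
  # Flatten so one pool handles every shard of every split.
  flat_bounds = list(itertools.chain(*shard_bounds_by_split.values()))

  write = functools.partial(write_shard_sketch, bytes_per_example=8)
  with multiprocessing.Pool(processes=2) as pool:
    sizes = pool.map(write, flat_bounds)  # Same order as flat_bounds.

  # Re-slice the flat results per split, as _write_shards does.
  sizes_by_split, idx = {}, 0
  for split, bounds in shard_bounds_by_split.items():
    sizes_by_split[split] = sizes[idx : idx + len(bounds)]
    idx += len(bounds)
  assert sizes_by_split == {'train': [24, 32, 24], 'test': [32]}
```

Because each worker pulls its own slice straight from the prepared HuggingFace dataset, this replaces the old single generator loop in `_generate_examples`, which is why that method and `_split_generators` can now raise `NotImplementedError`.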

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py

Lines changed: 0 additions & 14 deletions
@@ -15,25 +15,11 @@
 
 from unittest import mock
 
-from absl import logging
 import datasets as hf_datasets
 import pytest
 from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder
 
 
-def test_remove_empty_splits():
-  splits = {'non_empty_split': range(5), 'empty_split': range(0)}
-  with mock.patch.object(logging, 'log'):
-    non_empty_splits = huggingface_dataset_builder._remove_empty_splits(splits)
-    logging.log.assert_called_once_with(
-        logging.WARNING,
-        huggingface_dataset_builder._EMPTY_SPLIT_WARNING_MSG,
-        'empty_split',
-    )
-  assert non_empty_splits.keys() == {'non_empty_split'}
-  assert list(non_empty_splits['non_empty_split']) == list(range(5))
-
-
 class DummyHuggingfaceBuilder(hf_datasets.GeneratorBasedBuilder):
 
   def _info(self):
