Skip to content

Commit 8223a15

Browse files
tomvdw and The TensorFlow Datasets Authors
authored and committed
Add a ShardBasedBuilder that creates shards directly.
In certain cases, users have data available in different shards and they want to keep the same number of shards and in each shard the same order of examples (or they don't care about the ordering). In that case, our current dataset builder classes are much slower than necessary. The `ShardBasedBuilder` allows users to create dataset builders that process source data shard by shard. It can be run with or without Beam. In case of Beam, the resulting Beam pipeline is significantly simpler and therefore faster. PiperOrigin-RevId: 678137550
1 parent fc31737 commit 8223a15

File tree

5 files changed

+365
-24
lines changed

5 files changed

+365
-24
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
import abc
2121
import collections
22-
from collections.abc import Sequence
22+
from collections.abc import Iterable, Iterator, Mapping, Sequence
2323
import dataclasses
2424
import functools
2525
import inspect
2626
import json
2727
import os
2828
import sys
29-
from typing import Any, ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union
29+
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
3030

3131
from absl import logging
3232
from etils import epy
@@ -1445,6 +1445,17 @@ def builder_configs(cls) -> dict[str, BuilderConfig]:
14451445
)
14461446
return config_dict
14471447

1448+
def _get_filename_template(
1449+
self, split_name: str
1450+
) -> naming.ShardedFileTemplate:
1451+
"""Returns a filename template for the given split."""
1452+
return naming.ShardedFileTemplate(
1453+
split=split_name,
1454+
dataset_name=self.name,
1455+
data_dir=self.data_path,
1456+
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
1457+
)
1458+
14481459

14491460
class FileReaderBuilder(DatasetBuilder):
14501461
"""Base class for datasets reading files.
@@ -1675,17 +1686,6 @@ def _example_writer(self) -> writer_lib.ExampleWriter:
16751686
"""
16761687
return writer_lib.ExampleWriter(file_format=self.info.file_format)
16771688

1678-
def _get_filename_template(
1679-
self, split_name: str
1680-
) -> naming.ShardedFileTemplate:
1681-
"""Returns a filename template for the given split."""
1682-
return naming.ShardedFileTemplate(
1683-
split=split_name,
1684-
dataset_name=self.name,
1685-
data_dir=self.data_path,
1686-
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
1687-
)
1688-
16891689
def _generate_splits(
16901690
self,
16911691
dl_manager: download.DownloadManager,
@@ -1852,6 +1852,99 @@ def read_tfrecord_beam(
18521852
)
18531853

18541854

1855+
class ShardBasedBuilder(FileReaderBuilder):
1856+
"""Base class for datasets with data generated shard by shard.
1857+
1858+
Like `GeneratorBasedBuilder`, this base class can be used to create datasets.
1859+
However, `ShardBasedBuilder` gives strict control over the number of shards
1860+
and what data ends up in what shard.
1861+
1862+
This is useful for datasets where you want to keep the same ordering as the
1863+
original data source, and/or where you want to keep the same sharding as the
1864+
original data source.
1865+
1866+
You have to implement the `_shard_iterators_per_split` method, which returns
1867+
a mapping from split name to a list of `ExampleGeneratorFn` functions that
1868+
return an example iterator. The signature of the function is `Callable[[],
1869+
Iterator[KeyExample]]` where `KeyExample` is a tuple of (key, example) where
1870+
key is a unique key for the example and example is a dict of features.
1871+
1872+
Note that a `ExampleGeneratorFn` can also be a class that implements a
1873+
`__call__` method that returns a `Iterator[KeyExample]`.
1874+
1875+
Also note that shuffling is not supported. Also, the following fields in
1876+
`DownloadConfig` are not supported:
1877+
- `ignore_duplicates`
1878+
- `max_examples_per_split`
1879+
- `shard_config`
1880+
"""
1881+
1882+
def _download_and_prepare(
1883+
self,
1884+
dl_manager: download.DownloadManager,
1885+
download_config: download.DownloadConfig | None = None,
1886+
) -> None:
1887+
download_config = download_config or download.DownloadConfig()
1888+
1889+
split_builder = split_builder_lib.SplitBuilder(
1890+
split_dict=self.info.splits,
1891+
features=self.info.features,
1892+
dataset_size=self.info.dataset_size,
1893+
beam_options=download_config.beam_options,
1894+
beam_runner=download_config.beam_runner,
1895+
example_writer=self._example_writer(),
1896+
# The following options are ignored by `ShardBasedBuilder`.
1897+
ignore_duplicates=None,
1898+
max_examples_per_split=None,
1899+
shard_config=None,
1900+
)
1901+
1902+
shard_iterators_per_split = self._shard_iterators_per_split(dl_manager)
1903+
split_info_futures = []
1904+
for split_name, example_gen_per_shard in shard_iterators_per_split.items():
1905+
logging.info("Generating split %s", split_name)
1906+
split_info_future = split_builder.submit_shard_based_generation(
1907+
split_name=split_name,
1908+
example_gen_per_shard=example_gen_per_shard,
1909+
filename_template=self._get_filename_template(split_name=split_name),
1910+
)
1911+
split_info_futures.append(split_info_future)
1912+
1913+
# Update the info object with the splits.
1914+
split_infos: list[splits_lib.SplitInfo] = [
1915+
future.result() for future in split_info_futures
1916+
]
1917+
split_dict = splits_lib.SplitDict(split_infos)
1918+
self.info.set_splits(split_dict)
1919+
1920+
@abc.abstractmethod
1921+
@utils.docs.do_not_doc_in_subclasses
1922+
@utils.docs.doc_private
1923+
def _shard_iterators_per_split(
1924+
self, dl_manager: download.DownloadManager
1925+
) -> Mapping[str, Sequence[split_builder_lib.ExampleGeneratorFn]]:
1926+
"""Returns a mapping from split name to example generators per shard.
1927+
1928+
The example generators are functions with signature `Callable[[],
1929+
Iterator[KeyExample]]` that take no parameters and return
1930+
an iterator of tuples of (key, example). The order of the example generators
1931+
is the order in which the shards will be written.
1932+
1933+
Args:
1934+
dl_manager: `tfds.download.DownloadManager` used to download/extract the
1935+
data.
1936+
"""
1937+
raise NotImplementedError()
1938+
1939+
def _example_writer(self) -> writer_lib.ExampleWriter:
1940+
"""Returns an example writer.
1941+
1942+
If datasets should be written to a custom storage, e.g., a database, then
1943+
implement a custom `ExampleWriter` and inject it here.
1944+
"""
1945+
return writer_lib.ExampleWriter(file_format=self.info.file_format)
1946+
1947+
18551948
@utils.docs.deprecated
18561949
class BeamBasedBuilder(GeneratorBasedBuilder):
18571950
"""Beam based Builder.

tensorflow_datasets/core/dataset_builder_beam_test.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Tests for tensorflow_datasets.core.dataset_builder."""
17-
16+
import functools
1817
import pathlib
1918
from typing import Callable
2019
from unittest import mock
@@ -102,6 +101,31 @@ def _generate_examples(self, examples, num_examples):
102101
return examples
103102

104103

104+
class ShardBuilderBeam(dataset_builder.ShardBasedBuilder):
105+
VERSION = utils.Version('0.0.1')
106+
107+
def _info(self):
108+
return dataset_info.DatasetInfo(
109+
builder=self,
110+
features=features.FeaturesDict({'x': np.int64}),
111+
)
112+
113+
def _shard_iterators_per_split(self, dl_manager):
114+
del dl_manager
115+
116+
def gen_examples(start: int, end: int):
117+
for i in range(start, end):
118+
yield i, {'x': i}
119+
120+
return {
121+
'train': [
122+
functools.partial(gen_examples, start=0, end=10),
123+
functools.partial(gen_examples, start=10, end=20),
124+
],
125+
'test': [functools.partial(gen_examples, start=100, end=110)],
126+
}
127+
128+
105129
def _gen_example(x):
106130
return (
107131
x,
@@ -198,6 +222,26 @@ def _assert_values_equal(nested_lhs, nested_rhs):
198222
np.testing.assert_array_equal(lhs, rhs)
199223

200224

225+
@pytest.mark.parametrize(
226+
'make_dl_config',
227+
[
228+
make_default_config,
229+
],
230+
)
231+
def test_beam_shard_builder_dataset(
232+
tmp_path: pathlib.Path,
233+
make_dl_config: Callable[[], download.DownloadConfig],
234+
):
235+
builder = ShardBuilderBeam(data_dir=tmp_path, version='0.0.1')
236+
builder.download_and_prepare(
237+
file_format='array_record', download_config=make_dl_config()
238+
)
239+
actual_train_data = list(builder.as_data_source(split='train'))
240+
assert actual_train_data == [{'x': i} for i in range(20)]
241+
actual_test_data = list(builder.as_data_source(split='test'))
242+
assert actual_test_data == [{'x': i} for i in range(100, 110)]
243+
244+
201245
def test_read_tfrecord_beam():
202246
builder = DummyBeamDataset()
203247
with mock.patch.object(

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
"""Tests for tensorflow_datasets.core.dataset_builder."""
1717

18+
from collections.abc import Iterator, Mapping, Sequence
1819
import dataclasses
20+
import functools
1921
import os
2022
import tempfile
2123
from unittest import mock
@@ -37,9 +39,11 @@
3739
from tensorflow_datasets.core import load
3840
from tensorflow_datasets.core import naming
3941
from tensorflow_datasets.core import read_only_builder
42+
from tensorflow_datasets.core import split_builder
4043
from tensorflow_datasets.core import splits as splits_lib
4144
from tensorflow_datasets.core import utils
4245
from tensorflow_datasets.core.data_sources import array_record
46+
from tensorflow_datasets.core.download import download_manager
4347
from tensorflow_datasets.core.utils import file_utils
4448
from tensorflow_datasets.core.utils import read_config as read_config_lib
4549
from tensorflow_datasets.testing.dummy_config_based_datasets.dummy_ds_1 import dummy_ds_1_dataset_builder
@@ -147,6 +151,50 @@ def _split_generators(self, _):
147151
return {"all": self._generate_examples(range(5))}
148152

149153

154+
class ShardBuilder(dataset_builder.ShardBasedBuilder):
155+
VERSION = utils.Version("0.0.1")
156+
BUILDER_CONFIGS = [DummyBuilderConfig(name="cfg1")]
157+
158+
def _info(self):
159+
return dataset_info.DatasetInfo(
160+
builder=self,
161+
features=features.FeaturesDict({"x": np.int64}),
162+
)
163+
164+
def _shard_iterators_per_split(
165+
self, dl_manager: download_manager.DownloadManager
166+
) -> Mapping[str, Sequence[Iterator[split_builder.KeyExample]]]:
167+
del dl_manager
168+
169+
def gen_examples(
170+
start: int, end: int
171+
) -> Iterator[split_builder.KeyExample]:
172+
for i in range(start, end):
173+
yield i, {"x": i}
174+
175+
return {
176+
# train split has 2 shards
177+
"train": [
178+
functools.partial(gen_examples, start=0, end=10),
179+
functools.partial(gen_examples, start=10, end=20),
180+
],
181+
"test": [functools.partial(gen_examples, start=100, end=110)],
182+
}
183+
184+
185+
class ShardBuilderTest(testing.TestCase):
186+
187+
def test_download_and_prepare(self):
188+
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
189+
builder = ShardBuilder(data_dir=tmp_dir, config="cfg1", version="0.0.1")
190+
builder.download_and_prepare(file_format="array_record")
191+
actual_data = list(builder.as_data_source(split="train"))
192+
self.assertEqual(
193+
actual_data,
194+
[{"x": i} for i in range(20)],
195+
)
196+
197+
150198
class GetBuilderDatadirPathTest(testing.TestCase):
151199

152200
def test_builder_data_dir_path_is_correct(self):

0 commit comments

Comments
 (0)