|
19 | 19 |
|
20 | 20 | import abc
|
21 | 21 | import collections
|
22 |
| -from collections.abc import Sequence |
| 22 | +from collections.abc import Iterable, Iterator, Mapping, Sequence |
23 | 23 | import dataclasses
|
24 | 24 | import functools
|
25 | 25 | import inspect
|
26 | 26 | import json
|
27 | 27 | import os
|
28 | 28 | import sys
|
29 |
| -from typing import Any, ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union |
| 29 | +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union |
30 | 30 |
|
31 | 31 | from absl import logging
|
32 | 32 | from etils import epy
|
@@ -1445,6 +1445,17 @@ def builder_configs(cls) -> dict[str, BuilderConfig]:
|
1445 | 1445 | )
|
1446 | 1446 | return config_dict
|
1447 | 1447 |
|
| 1448 | + def _get_filename_template( |
| 1449 | + self, split_name: str |
| 1450 | + ) -> naming.ShardedFileTemplate: |
| 1451 | + """Returns a filename template for the given split.""" |
| 1452 | + return naming.ShardedFileTemplate( |
| 1453 | + split=split_name, |
| 1454 | + dataset_name=self.name, |
| 1455 | + data_dir=self.data_path, |
| 1456 | + filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error |
| 1457 | + ) |
| 1458 | + |
1448 | 1459 |
|
1449 | 1460 | class FileReaderBuilder(DatasetBuilder):
|
1450 | 1461 | """Base class for datasets reading files.
|
@@ -1675,17 +1686,6 @@ def _example_writer(self) -> writer_lib.ExampleWriter:
|
1675 | 1686 | """
|
1676 | 1687 | return writer_lib.ExampleWriter(file_format=self.info.file_format)
|
1677 | 1688 |
|
1678 |
| - def _get_filename_template( |
1679 |
| - self, split_name: str |
1680 |
| - ) -> naming.ShardedFileTemplate: |
1681 |
| - """Returns a filename template for the given split.""" |
1682 |
| - return naming.ShardedFileTemplate( |
1683 |
| - split=split_name, |
1684 |
| - dataset_name=self.name, |
1685 |
| - data_dir=self.data_path, |
1686 |
| - filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error |
1687 |
| - ) |
1688 |
| - |
1689 | 1689 | def _generate_splits(
|
1690 | 1690 | self,
|
1691 | 1691 | dl_manager: download.DownloadManager,
|
@@ -1852,6 +1852,99 @@ def read_tfrecord_beam(
|
1852 | 1852 | )
|
1853 | 1853 |
|
1854 | 1854 |
|
| 1855 | +class ShardBasedBuilder(FileReaderBuilder): |
| 1856 | + """Base class for datasets with data generated shard by shard. |
| 1857 | +
|
| 1858 | + Like `GeneratorBasedBuilder`, this base class can be used to create datasets. |
| 1859 | + However, `ShardBasedBuilder` gives strict control over the number of shards |
| 1860 | + and what data ends up in what shard. |
| 1861 | +
|
| 1862 | + This is useful for datasets where you want to keep the same ordering as the |
| 1863 | + original data source, and/or where you want to keep the same sharding as the |
| 1864 | + original data source. |
| 1865 | +
|
| 1866 | + You have to implement the `_shard_iterators_per_split` method, which returns |
| 1867 | + a mapping from split name to a list of `ExampleGeneratorFn` functions that |
| 1868 | + return an example iterator. The signature of the function is `Callable[[], |
| 1869 | + Iterator[KeyExample]]`, where `KeyExample` is a tuple of (key, example): |
| 1870 | + `key` is a unique key for the example and `example` is a dict of features. |
| 1871 | +
|
| 1872 | + Note that an `ExampleGeneratorFn` can also be a class that implements a |
| 1873 | + `__call__` method that returns an `Iterator[KeyExample]`. |
| 1874 | +
|
| 1875 | + Note that shuffling is not supported. In addition, the following fields in |
| 1876 | + `DownloadConfig` are not supported: |
| 1877 | + - `ignore_duplicates` |
| 1878 | + - `max_examples_per_split` |
| 1879 | + - `shard_config` |
| 1880 | + """ |
| 1881 | + |
| 1882 | + def _download_and_prepare( |
| 1883 | + self, |
| 1884 | + dl_manager: download.DownloadManager, |
| 1885 | + download_config: download.DownloadConfig | None = None, |
| 1886 | + ) -> None: |
| 1887 | + download_config = download_config or download.DownloadConfig() |
| 1888 | + |
| 1889 | + split_builder = split_builder_lib.SplitBuilder( |
| 1890 | + split_dict=self.info.splits, |
| 1891 | + features=self.info.features, |
| 1892 | + dataset_size=self.info.dataset_size, |
| 1893 | + beam_options=download_config.beam_options, |
| 1894 | + beam_runner=download_config.beam_runner, |
| 1895 | + example_writer=self._example_writer(), |
| 1896 | + # The following options are ignored by `ShardBasedBuilder`. |
| 1897 | + ignore_duplicates=None, |
| 1898 | + max_examples_per_split=None, |
| 1899 | + shard_config=None, |
| 1900 | + ) |
| 1901 | + |
| 1902 | + shard_iterators_per_split = self._shard_iterators_per_split(dl_manager) |
| 1903 | + split_info_futures = [] |
| 1904 | + for split_name, example_gen_per_shard in shard_iterators_per_split.items(): |
| 1905 | + logging.info("Generating split %s", split_name) |
| 1906 | + split_info_future = split_builder.submit_shard_based_generation( |
| 1907 | + split_name=split_name, |
| 1908 | + example_gen_per_shard=example_gen_per_shard, |
| 1909 | + filename_template=self._get_filename_template(split_name=split_name), |
| 1910 | + ) |
| 1911 | + split_info_futures.append(split_info_future) |
| 1912 | + |
| 1913 | + # Update the info object with the splits. |
| 1914 | + split_infos: list[splits_lib.SplitInfo] = [ |
| 1915 | + future.result() for future in split_info_futures |
| 1916 | + ] |
| 1917 | + split_dict = splits_lib.SplitDict(split_infos) |
| 1918 | + self.info.set_splits(split_dict) |
| 1919 | + |
| 1920 | + @abc.abstractmethod |
| 1921 | + @utils.docs.do_not_doc_in_subclasses |
| 1922 | + @utils.docs.doc_private |
| 1923 | + def _shard_iterators_per_split( |
| 1924 | + self, dl_manager: download.DownloadManager |
| 1925 | + ) -> Mapping[str, Sequence[split_builder_lib.ExampleGeneratorFn]]: |
| 1926 | + """Returns a mapping from split name to example generators per shard. |
| 1927 | +
|
| 1928 | + The example generators are functions with signature `Callable[[], |
| 1929 | + Iterator[KeyExample]]` that take no parameters and return |
| 1930 | + an iterator of tuples of (key, example). The order of the example generators |
| 1931 | + is the order in which the shards will be written. |
| 1932 | +
|
| 1933 | + Args: |
| 1934 | + dl_manager: `tfds.download.DownloadManager` used to download/extract the |
| 1935 | + data. |
| 1936 | + """ |
| 1937 | + raise NotImplementedError() |
| 1938 | + |
| 1939 | + def _example_writer(self) -> writer_lib.ExampleWriter: |
| 1940 | + """Returns an example writer. |
| 1941 | +
|
| 1942 | + If datasets should be written to a custom storage, e.g., a database, then |
| 1943 | + implement a custom `ExampleWriter` and inject it here. |
| 1944 | + """ |
| 1945 | + return writer_lib.ExampleWriter(file_format=self.info.file_format) |
| 1946 | + |
| 1947 | + |
1855 | 1948 | @utils.docs.deprecated
|
1856 | 1949 | class BeamBasedBuilder(GeneratorBasedBuilder):
|
1857 | 1950 | """Beam based Builder.
|
|
0 commit comments