 from typing import Any, ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union

 from absl import logging
-from etils import epath
-import importlib_resources
-from tensorflow_datasets.core import constants
-from tensorflow_datasets.core import dataset_info
-from tensorflow_datasets.core import dataset_metadata
-from tensorflow_datasets.core import decode
-from tensorflow_datasets.core import download
-from tensorflow_datasets.core import file_adapters
-from tensorflow_datasets.core import lazy_imports_lib
-from tensorflow_datasets.core import logging as tfds_logging
-from tensorflow_datasets.core import naming
-from tensorflow_datasets.core import reader as reader_lib
-from tensorflow_datasets.core import registered
-from tensorflow_datasets.core import split_builder as split_builder_lib
-from tensorflow_datasets.core import splits as splits_lib
-from tensorflow_datasets.core import tf_compat
-from tensorflow_datasets.core import units
-from tensorflow_datasets.core import utils
-from tensorflow_datasets.core import writer as writer_lib
-from tensorflow_datasets.core.data_sources import array_record
-from tensorflow_datasets.core.data_sources import parquet
-from tensorflow_datasets.core.proto import dataset_info_pb2
-from tensorflow_datasets.core.utils import file_utils
-from tensorflow_datasets.core.utils import gcs_utils
-from tensorflow_datasets.core.utils import read_config as read_config_lib
-from tensorflow_datasets.core.utils import type_utils
+from etils import epy
 from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree
-import termcolor

+with epy.lazy_imports():
+  # pylint: disable=g-import-not-at-top
+  from etils import epath
+  import importlib_resources
+  import termcolor
+
+  from tensorflow_datasets.core import constants
+  from tensorflow_datasets.core import dataset_info
+  from tensorflow_datasets.core import dataset_metadata
+  from tensorflow_datasets.core import decode
+  from tensorflow_datasets.core import download
+  from tensorflow_datasets.core import file_adapters
+  from tensorflow_datasets.core import lazy_imports_lib
+  from tensorflow_datasets.core import logging as tfds_logging
+  from tensorflow_datasets.core import naming
+  from tensorflow_datasets.core import reader as reader_lib
+  from tensorflow_datasets.core import registered
+  from tensorflow_datasets.core import split_builder as split_builder_lib
+  from tensorflow_datasets.core import splits as splits_lib
+  from tensorflow_datasets.core import tf_compat
+  from tensorflow_datasets.core import units
+  from tensorflow_datasets.core import utils
+  from tensorflow_datasets.core import writer as writer_lib
+  from tensorflow_datasets.core.data_sources import array_record
+  from tensorflow_datasets.core.data_sources import parquet
+  from tensorflow_datasets.core.proto import dataset_info_pb2
+  from tensorflow_datasets.core.utils import file_utils
+  from tensorflow_datasets.core.utils import gcs_utils
+  from tensorflow_datasets.core.utils import read_config as read_config_lib
+  from tensorflow_datasets.core.utils import type_utils
+  # pylint: enable=g-import-not-at-top

 ListOrTreeOrElem = type_utils.ListOrTreeOrElem
 Tree = type_utils.Tree
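The `epy.lazy_imports()` block defers each import in its body until the module is first used, which keeps `import tensorflow_datasets` itself fast even though `dataset_builder.py` depends on most of `tensorflow_datasets.core`. A minimal sketch of the pattern, assuming only that `etils` is installed (`json` stands in here for a heavy dependency):

from etils import epy

with epy.lazy_imports():
  # pylint: disable=g-import-not-at-top
  import json  # Binds a lazy proxy; the module is not imported yet.
  # pylint: enable=g-import-not-at-top

# First attribute access triggers the real import.
print(json.dumps({"lazy": True}))

The proxy only pays off if nothing else imports the same module eagerly in the same process.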
@@ -726,6 +731,17 @@ def download_and_prepare(

     self._log_download_done()

+    # Execute the post download and prepare hook if it exists.
+    self._post_download_and_prepare_hook()
+
+
+  def _post_download_and_prepare_hook(self) -> None:
+    """Hook to be executed after download and prepare.
+
+    Override this in custom dataset builders to run custom logic after
+    download and prepare.
+    """
+    pass

   def _update_dataset_info(self) -> None:
     """Updates the `dataset_info.json` file in the dataset dir."""
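The new method is a no-op extension point: `download_and_prepare` now always calls it after logging that preparation is done, and subclasses decide whether it does anything. A sketch of an override in a custom builder; the builder below is illustrative and not part of this change:

import numpy as np
from absl import logging
import tensorflow_datasets as tfds


class MyToyDataset(tfds.core.GeneratorBasedBuilder):
  """Toy builder; only the hook override at the bottom is relevant."""

  VERSION = tfds.core.Version("1.0.0")

  def _info(self) -> tfds.core.DatasetInfo:
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict(
            {"value": tfds.features.Scalar(dtype=np.int64)}
        ),
    )

  def _split_generators(self, dl_manager):
    return {"train": self._generate_examples()}

  def _generate_examples(self):
    for i in range(3):
      yield i, {"value": i}

  def _post_download_and_prepare_hook(self) -> None:
    # Runs once, at the very end of download_and_prepare(), after all
    # splits have been written.
    logging.info("Finished preparing %s", self.name)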
@@ -767,33 +783,56 @@ def as_data_source(
     if split is None:
       split = {s: s for s in self.info.splits}

-    # Create a dataset for each of the given splits
-    def build_single_data_source(
-        split: str,
-    ) -> Sequence[Any]:
-      file_format = self.info.file_format
-      if file_format == file_adapters.FileFormat.ARRAY_RECORD:
-        return array_record.ArrayRecordDataSource(
-            self.info,
-            split=split,
-            decoders=decoders,
+    info = self.info
+
+    random_access_formats = file_adapters.FileFormat.with_random_access()
+    random_access_formats_msg = " or ".join(
+        [f.value for f in random_access_formats]
+    )
+    unsupported_format_msg = (
+        f"Random access data source for file format {info.file_format} is"
+        " not supported. Can you try to run download_and_prepare with"
+        f" file_format set to one of: {random_access_formats_msg}?"
+    )
+
+    if info.file_format is None and not info.alternative_file_formats:
+      raise ValueError(
+          "Dataset info file format is not set! For random access, one of the"
+          f" following formats is required: {random_access_formats_msg}"
+      )
+
+    if (
+        info.file_format is None
+        or info.file_format not in random_access_formats
+    ):
+      available_formats = set(info.alternative_file_formats)
+      suitable_formats = available_formats.intersection(random_access_formats)
+      if suitable_formats:
+        chosen_format = suitable_formats.pop()
+        logging.info(
+            "Found random access formats: %s. Chose to use %s. Overriding file"
+            " format in the dataset info.",
+            ", ".join([f.name for f in suitable_formats]),
+            chosen_format,
         )
-      elif file_format == file_adapters.FileFormat.PARQUET:
-        return parquet.ParquetDataSource(
-            self.info,
-            split=split,
-            decoders=decoders,
+        # Change the dataset info to read from a random access format.
+        info.set_file_format(
+            chosen_format, override=True, override_if_initialized=True
         )
       else:
-        args = [
-            f"`file_format='{file_format.value}'`"
-            for file_format in file_adapters.FileFormat.with_random_access()
-        ]
-        raise NotImplementedError(
-            f"Random access data source for file format {file_format} is not"
-            " supported. Can you try to run download_and_prepare with"
-            f" {' or '.join(args)}?"
-        )
+        raise NotImplementedError(unsupported_format_msg)
+
+    # Create a data source for each of the given splits
+    def build_single_data_source(split: str) -> Sequence[Any]:
+      match info.file_format:
+        case file_adapters.FileFormat.ARRAY_RECORD:
+          return array_record.ArrayRecordDataSource(
+              info, split=split, decoders=decoders
+          )
+        case file_adapters.FileFormat.PARQUET:
+          return parquet.ParquetDataSource(info, split=split, decoders=decoders)
+        case _:
+          raise NotImplementedError(unsupported_format_msg)

     all_ds = tree.map_structure(build_single_data_source, split)
     return all_ds
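With this change `as_data_source` negotiates the file format up front: if the primary format is not random-access but one of `info.alternative_file_formats` is, the dataset info is switched to that format before any source is built; otherwise the shared `unsupported_format_msg` is raised, both during negotiation and in the per-split `match`. A usage sketch (the dataset name is illustrative; it assumes the data was prepared in a random-access format):

import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
# Random access requires array_record or parquet; plain tfrecord files
# would make as_data_source raise NotImplementedError.
builder.download_and_prepare(file_format="array_record")

sources = builder.as_data_source()             # One data source per split.
train = builder.as_data_source(split="train")

print(len(train))  # Split size, known without scanning the data.
print(train[0])    # Random access by index decodes only that record.

One design note: `suitable_formats.pop()` removes an arbitrary element from a set, so when several alternative formats qualify, the chosen format (and the remainder logged after the pop) can vary between runs.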