|
21 | 21 | import dataclasses
|
22 | 22 | import difflib
|
23 | 23 | import json
|
24 |
| -import os |
25 | 24 | import posixpath
|
26 | 25 | import re
|
27 | 26 | import textwrap
|
28 | 27 | import typing
|
29 |
| -from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Type, Union |
| 28 | +from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Type |
30 | 29 |
|
31 | 30 | from absl import logging
|
| 31 | +from etils import epath |
32 | 32 | from tensorflow_datasets.core import community
|
33 | 33 | from tensorflow_datasets.core import constants
|
34 | 34 | from tensorflow_datasets.core import dataset_builder
|
@@ -225,7 +225,7 @@ def get_dataset_repr() -> str:
|
225 | 225 |
|
226 | 226 |
|
227 | 227 | def _try_load_from_files_first(
|
228 |
| - cls: Optional[Type[dataset_builder.DatasetBuilder]], |
| 228 | + cls: Type[dataset_builder.DatasetBuilder] | None, |
229 | 229 | **builder_kwargs: Any,
|
230 | 230 | ) -> bool:
|
231 | 231 | """Returns True if files should be used rather than code."""
|
@@ -487,8 +487,8 @@ def dataset_collection(
|
487 | 487 |
|
488 | 488 | def _fetch_builder(
|
489 | 489 | name: str,
|
490 |
| - data_dir: Union[None, str, os.PathLike], # pylint: disable=g-bare-generic |
491 |
| - builder_kwargs: Optional[Dict[str, Any]], |
| 490 | + data_dir: epath.PathLike | None, |
| 491 | + builder_kwargs: dict[str, Any] | None, |
492 | 492 | try_gcs: bool,
|
493 | 493 | ) -> dataset_builder.DatasetBuilder:
|
494 | 494 | """Fetches the `tfds.core.DatasetBuilder` by name."""
|
@@ -521,18 +521,18 @@ def _download_and_prepare_builder(
|
521 | 521 | def load(
|
522 | 522 | name: str,
|
523 | 523 | *,
|
524 |
| - split: Optional[Tree[splits_lib.SplitArg]] = None, |
525 |
| - data_dir: Union[None, str, os.PathLike] = None, # pylint: disable=g-bare-generic |
526 |
| - batch_size: Optional[int] = None, |
| 524 | + split: Tree[splits_lib.SplitArg] | None = None, |
| 525 | + data_dir: epath.PathLike | None = None, |
| 526 | + batch_size: int | None = None, |
527 | 527 | shuffle_files: bool = False,
|
528 | 528 | download: bool = True,
|
529 | 529 | as_supervised: bool = False,
|
530 |
| - decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None, |
531 |
| - read_config: Optional[read_config_lib.ReadConfig] = None, |
| 530 | + decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None, |
| 531 | + read_config: read_config_lib.ReadConfig | None = None, |
532 | 532 | with_info: bool = False,
|
533 |
| - builder_kwargs: Optional[Dict[str, Any]] = None, |
534 |
| - download_and_prepare_kwargs: Optional[Dict[str, Any]] = None, |
535 |
| - as_dataset_kwargs: Optional[Dict[str, Any]] = None, |
| 533 | + builder_kwargs: dict[str, Any] | None = None, |
| 534 | + download_and_prepare_kwargs: dict[str, Any] | None = None, |
| 535 | + as_dataset_kwargs: dict[str, Any] | None = None, |
536 | 536 | try_gcs: bool = False,
|
537 | 537 | ):
|
538 | 538 | # pylint: disable=line-too-long
|
@@ -677,37 +677,49 @@ def load(
|
677 | 677 |
|
678 | 678 |
|
def _set_file_format_for_data_source(
    data_dir: epath.PathLike | None,
    builder_kwargs: dict[str, Any] | None,
) -> dict[str, Any]:
  """Normalizes file format in builder_kwargs for `tfds.data_source`."""
  if builder_kwargs is None:
    builder_kwargs = {}
  # Only when the caller gave neither builder_kwargs nor a data_dir do we pick
  # the default random-access format; otherwise their choice is left untouched.
  if not builder_kwargs and not data_dir:
    return {'file_format': file_adapters.FileFormat.ARRAY_RECORD}
  return builder_kwargs
| 692 | + |
| 693 | +def _validate_file_format_for_data_source( |
| 694 | + builder_kwargs: dict[str, Any], |
| 695 | +) -> None: |
| 696 | + """Validates whether the file format supports random access.""" |
| 697 | + file_format = builder_kwargs.get('file_format') |
| 698 | + if not file_format: |
| 699 | + # We don't raise an error because we let TFDS handle the default (e.g., |
| 700 | + # when loading a dataset from files that support random access). |
| 701 | + return |
688 | 702 | file_format = file_adapters.FileFormat.from_value(file_format)
|
689 |
| - if file_format != file_adapters.FileFormat.ARRAY_RECORD: |
| 703 | + if file_format not in file_adapters.FileFormat.with_random_access(): |
690 | 704 | raise NotImplementedError(
|
691 | 705 | f'No random access data source for file format {file_format}. Please,'
|
692 | 706 | ' use `tfds.data_source(...,'
|
693 | 707 | ' builder_kwargs={"file_format":'
|
694 | 708 | f' {file_adapters.FileFormat.ARRAY_RECORD}}})` instead.'
|
695 | 709 | )
|
696 |
| - builder_kwargs['file_format'] = file_format |
697 |
| - return builder_kwargs |
698 | 710 |
|
699 | 711 |
|
700 | 712 | @tfds_logging.data_source()
|
701 | 713 | def data_source(
|
702 | 714 | name: str,
|
703 | 715 | *,
|
704 |
| - split: Optional[Tree[splits_lib.SplitArg]] = None, |
705 |
| - data_dir: Union[None, str, os.PathLike] = None, # pylint: disable=g-bare-generic |
| 716 | + split: Tree[splits_lib.SplitArg] | None = None, |
| 717 | + data_dir: epath.PathLike | None = None, |
706 | 718 | download: bool = True,
|
707 |
| - decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None, |
| 719 | + decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None, |
708 | 720 | deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
|
709 |
| - builder_kwargs: Optional[Dict[str, Any]] = None, |
710 |
| - download_and_prepare_kwargs: Optional[Dict[str, Any]] = None, |
| 721 | + builder_kwargs: dict[str, Any] | None = None, |
| 722 | + download_and_prepare_kwargs: dict[str, Any] | None = None, |
711 | 723 | try_gcs: bool = False,
|
712 | 724 | ) -> type_utils.ListOrTreeOrElem[Sequence[Any]]:
|
713 | 725 | """Gets a data source from the named dataset.
|
@@ -805,7 +817,8 @@ def data_source(
|
805 | 817 | `Sequence` if `split`,
|
806 | 818 | `dict<key: tfds.Split, value: Sequence>` otherwise.
|
807 | 819 | """ # fmt:skip
|
808 |
| - builder_kwargs = _set_file_format_for_data_source(builder_kwargs) |
| 820 | + builder_kwargs = _set_file_format_for_data_source(data_dir, builder_kwargs) |
| 821 | + _validate_file_format_for_data_source(builder_kwargs) |
809 | 822 | dbuilder = _fetch_builder(
|
810 | 823 | name,
|
811 | 824 | data_dir,
|
|
0 commit comments