@@ -563,9 +563,9 @@ def get_reference(
563
563
data_dir = self .data_dir_root ,
564
564
)
565
565
566
- def get_file_spec (self , split : str ) -> str :
566
+ def get_file_spec (self , split : str ) -> str | None :
567
567
"""Returns the file spec of the split."""
568
- split_info : splits_lib . SplitInfo = self .info .splits [split ]
568
+ split_info = self .info .splits [split ]
569
569
return split_info .file_spec (self .info .file_format )
570
570
571
571
def is_prepared (self ) -> bool :
@@ -815,6 +815,7 @@ def as_data_source(
815
815
* ,
816
816
decoders : TreeDict [decode .partial_decode .DecoderArg ] | None = None ,
817
817
deserialize_method : decode .DeserializeMethod = decode .DeserializeMethod .DESERIALIZE_AND_DECODE ,
818
+ file_format : str | file_adapters .FileFormat | None = None ,
818
819
) -> ListOrTreeOrElem [Sequence [Any ]]:
819
820
"""Constructs an `ArrayRecordDataSource`.
820
821
@@ -833,6 +834,9 @@ def as_data_source(
833
834
the features. Decoding is only supported if the examples are tf
834
835
examples. Note that if the deserialize_method method is other than
835
836
PARSE_AND_DECODE, then the `decoders` argument is ignored.
837
+ file_format: if the dataset is stored in multiple file formats, then this
838
+ can be used to specify which format to use. If not provided, we will
839
+ default to the first available format.
836
840
837
841
Returns:
838
842
`Sequence` if `split`,
@@ -868,22 +872,31 @@ def as_data_source(
868
872
"Dataset info file format is not set! For random access, one of the"
869
873
f" following formats is required: { random_access_formats_msg } "
870
874
)
871
-
872
875
suitable_formats = available_formats .intersection (random_access_formats )
873
- if suitable_formats :
876
+ if not suitable_formats :
877
+ raise NotImplementedError (unsupported_format_msg )
878
+
879
+ if file_format is not None :
880
+ file_format = file_adapters .FileFormat .from_value (file_format )
881
+ if file_format not in suitable_formats :
882
+ raise ValueError (
883
+ f"Requested file format { file_format } is not available for this"
884
+ f" dataset. Available formats: { available_formats } "
885
+ )
886
+ chosen_format = file_format
887
+ else :
874
888
chosen_format = suitable_formats .pop ()
875
889
logging .info (
876
890
"Found random access formats: %s. Chose to use %s. Overriding file"
877
891
" format in the dataset info." ,
878
892
", " .join ([f .name for f in suitable_formats ]),
879
893
chosen_format ,
880
894
)
881
- # Change the dataset info to read from a random access format.
882
- info .set_file_format (
883
- chosen_format , override = True , override_if_initialized = True
884
- )
885
- else :
886
- raise NotImplementedError (unsupported_format_msg )
895
+
896
+ # Change the dataset info to read from a random access format.
897
+ info .set_file_format (
898
+ chosen_format , override = True , override_if_initialized = True
899
+ )
887
900
888
901
# Create a dataset for each of the given splits
889
902
def build_single_data_source (split : str ) -> Sequence [Any ]:
@@ -924,6 +937,7 @@ def as_dataset(
924
937
decoders : TreeDict [decode .partial_decode .DecoderArg ] | None = None ,
925
938
read_config : read_config_lib .ReadConfig | None = None ,
926
939
as_supervised : bool = False ,
940
+ file_format : str | file_adapters .FileFormat | None = None ,
927
941
):
928
942
# pylint: disable=line-too-long
929
943
"""Constructs a `tf.data.Dataset`.
@@ -993,6 +1007,9 @@ def as_dataset(
993
1007
a 2-tuple structure `(input, label)` according to
994
1008
`builder.info.supervised_keys`. If `False`, the default, the returned
995
1009
`tf.data.Dataset` will have a dictionary with all the features.
1010
+ file_format: if the dataset is stored in multiple file formats, then this
1011
+ argument can be used to specify the file format to load. If not
1012
+ specified, the default file format is used.
996
1013
997
1014
Returns:
998
1015
`tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -1026,6 +1043,7 @@ def as_dataset(
1026
1043
decoders = decoders ,
1027
1044
read_config = read_config ,
1028
1045
as_supervised = as_supervised ,
1046
+ file_format = file_format ,
1029
1047
)
1030
1048
all_ds = tree .map_structure (build_single_dataset , split )
1031
1049
return all_ds
@@ -1038,19 +1056,29 @@ def _build_single_dataset(
1038
1056
decoders : TreeDict [decode .partial_decode .DecoderArg ] | None ,
1039
1057
read_config : read_config_lib .ReadConfig ,
1040
1058
as_supervised : bool ,
1059
+ file_format : str | file_adapters .FileFormat | None = None ,
1041
1060
) -> tf .data .Dataset :
1042
1061
"""as_dataset for a single split."""
1043
1062
wants_full_dataset = batch_size == - 1
1044
1063
if wants_full_dataset :
1045
1064
batch_size = self .info .splits .total_num_examples or sys .maxsize
1046
1065
1066
+ if file_format is not None :
1067
+ file_format = file_adapters .FileFormat .from_value (file_format )
1068
+
1047
1069
# Build base dataset
1048
- ds = self ._as_dataset (
1049
- split = split ,
1050
- shuffle_files = shuffle_files ,
1051
- decoders = decoders ,
1052
- read_config = read_config ,
1053
- )
1070
+ as_dataset_kwargs = {
1071
+ "split" : split ,
1072
+ "shuffle_files" : shuffle_files ,
1073
+ "decoders" : decoders ,
1074
+ "read_config" : read_config ,
1075
+ }
1076
+ # Not all dataset builder classes support file_format, so only pass it if
1077
+ # it's supported.
1078
+ if "file_format" in inspect .signature (self ._as_dataset ).parameters :
1079
+ as_dataset_kwargs ["file_format" ] = file_format
1080
+ ds = self ._as_dataset (** as_dataset_kwargs )
1081
+
1054
1082
# Auto-cache small datasets which are small enough to fit in memory.
1055
1083
if self ._should_cache_ds (
1056
1084
split = split , shuffle_files = shuffle_files , read_config = read_config
@@ -1235,6 +1263,7 @@ def _as_dataset(
1235
1263
decoders : TreeDict [decode .partial_decode .DecoderArg ] | None = None ,
1236
1264
read_config : read_config_lib .ReadConfig | None = None ,
1237
1265
shuffle_files : bool = False ,
1266
+ file_format : str | file_adapters .FileFormat | None = None ,
1238
1267
) -> tf .data .Dataset :
1239
1268
"""Constructs a `tf.data.Dataset`.
1240
1269
@@ -1250,6 +1279,9 @@ def _as_dataset(
1250
1279
read_config: `tfds.ReadConfig`
1251
1280
shuffle_files: `bool`, whether to shuffle the input files. Optional,
1252
1281
defaults to `False`.
1282
+ file_format: if the dataset is stored in multiple file formats, then this
1283
+ argument can be used to specify the file format to load. If not
1284
+ specified, the default file format is used.
1253
1285
1254
1286
Returns:
1255
1287
`tf.data.Dataset`
@@ -1487,6 +1519,10 @@ def __init__(
1487
1519
1488
1520
@functools .cached_property
1489
1521
def _example_specs (self ):
1522
+ if self .info .features is None :
1523
+ raise ValueError (
1524
+ f"Features are not set for dataset { self .name } in { self .data_dir } !"
1525
+ )
1490
1526
return self .info .features .get_serialized_info ()
1491
1527
1492
1528
def _as_dataset ( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
@@ -1495,6 +1531,7 @@ def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-t
1495
1531
decoders : TreeDict [decode .partial_decode .DecoderArg ] | None ,
1496
1532
read_config : read_config_lib .ReadConfig ,
1497
1533
shuffle_files : bool ,
1534
+ file_format : file_adapters .FileFormat | None = None ,
1498
1535
) -> tf .data .Dataset :
1499
1536
# Partial decoding
1500
1537
# TODO(epot): Should be moved inside `features.decode_example`
@@ -1508,10 +1545,15 @@ def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-t
1508
1545
example_specs = self ._example_specs
1509
1546
decoders = decoders # pylint: disable=self-assigning-variable
1510
1547
1548
+ if features is None :
1549
+ raise ValueError (
1550
+ f"Features are not set for dataset { self .name } in { self .data_dir } !"
1551
+ )
1552
+
1511
1553
reader = reader_lib .Reader (
1512
1554
self .data_dir ,
1513
1555
example_specs = example_specs ,
1514
- file_format = self .info .file_format ,
1556
+ file_format = file_format or self .info .file_format ,
1515
1557
)
1516
1558
decode_fn = functools .partial (features .decode_example , decoders = decoders )
1517
1559
return reader .read (
0 commit comments