18
18
from __future__ import annotations
19
19
20
20
import abc
21
- from collections .abc import Iterable
21
+ from collections .abc import Iterable , Sequence
22
22
import dataclasses
23
23
import functools
24
24
import itertools
@@ -123,7 +123,7 @@ def __post_init__(self):
123
123
def get_available_shards (
124
124
self ,
125
125
data_dir : epath .Path | None = None ,
126
- file_format : file_adapters .FileFormat | None = None ,
126
+ file_format : str | file_adapters .FileFormat | None = None ,
127
127
strict_matching : bool = True ,
128
128
) -> list [epath .Path ]:
129
129
"""Returns the list of shards that are present in the data dir.
@@ -140,6 +140,7 @@ def get_available_shards(
140
140
"""
141
141
if filename_template := self .filename_template :
142
142
if file_format :
143
+ file_format = file_adapters .FileFormat .from_value (file_format )
143
144
filename_template = filename_template .replace (
144
145
filetype_suffix = file_format .file_suffix
145
146
)
@@ -250,7 +251,9 @@ def replace(self, **kwargs: Any) -> SplitInfo:
250
251
"""Returns a copy of the `SplitInfo` with updated attributes."""
251
252
return dataclasses .replace (self , ** kwargs )
252
253
253
- def file_spec (self , file_format : file_adapters .FileFormat ) -> str :
254
+ def file_spec (
255
+ self , file_format : str | file_adapters .FileFormat
256
+ ) -> str | None :
254
257
"""Returns the file spec of the split for the given file format.
255
258
256
259
A file spec is the full path with sharded notation, e.g.,
@@ -259,6 +262,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
259
262
Args:
260
263
file_format: the file format for which to create the file spec.
261
264
"""
265
+ file_format = file_adapters .FileFormat .from_value (file_format )
262
266
if filename_template := self .filename_template :
263
267
if filename_template .filetype_suffix != file_format .file_suffix :
264
268
raise ValueError (
@@ -268,9 +272,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
268
272
return filename_template .sharded_filepaths_pattern (
269
273
num_shards = self .num_shards
270
274
)
271
- raise ValueError (
272
- f'Could not get filename template for split from split info: { self } .'
273
- )
275
+ return None
274
276
275
277
276
278
@dataclasses .dataclass (eq = False , frozen = True )
@@ -425,7 +427,7 @@ def __repr__(self) -> str:
425
427
if typing .TYPE_CHECKING :
426
428
# For type checking, `tfds.Split` is an alias for `str` with additional
427
429
# `.TRAIN`, `.TEST`,... attributes. All strings are valid split type.
428
- Split = Union [ Split , str ]
430
+ Split = Split | str
429
431
430
432
431
433
class SplitDict (utils .NonMutableDict [str , SplitInfo ]):
@@ -438,7 +440,7 @@ def __init__(
438
440
# TODO(b/216470058): remove this parameter
439
441
dataset_name : str | None = None , # deprecated, please don't use
440
442
):
441
- super (SplitDict , self ).__init__ (
443
+ super ().__init__ (
442
444
{split_info .name : split_info for split_info in split_infos },
443
445
error_msg = 'Split {key} already present' ,
444
446
)
@@ -457,7 +459,7 @@ def __getitem__(self, key) -> SplitInfo | SubSplitInfo:
457
459
)
458
460
# 1st case: The key exists: `info.splits['train']`
459
461
elif str (key ) in self .keys ():
460
- return super (SplitDict , self ).__getitem__ (str (key ))
462
+ return super ().__getitem__ (str (key ))
461
463
# 2nd case: Uses instructions: `info.splits['train[50%]']`
462
464
else :
463
465
instructions = _make_file_instructions (
@@ -543,7 +545,7 @@ def _file_instructions_for_split(
543
545
544
546
545
547
def _make_file_instructions (
546
- split_infos : list [SplitInfo ],
548
+ split_infos : Sequence [SplitInfo ],
547
549
instruction : SplitArg ,
548
550
) -> list [shard_utils .FileInstruction ]:
549
551
"""Returns file instructions by applying the given instruction on the given splits.
@@ -587,7 +589,7 @@ class AbstractSplit(abc.ABC):
587
589
"""
588
590
589
591
@classmethod
590
- def from_spec (cls , spec : SplitArg ) -> ' AbstractSplit' :
592
+ def from_spec (cls , spec : SplitArg ) -> AbstractSplit :
591
593
"""Creates a ReadInstruction instance out of a string spec.
592
594
593
595
Args:
@@ -632,7 +634,7 @@ def to_absolute(self, split_infos) -> list[_AbsoluteInstruction]:
632
634
"""
633
635
raise NotImplementedError
634
636
635
- def __add__ (self , other : Union [ str , ' AbstractSplit' ] ) -> ' AbstractSplit' :
637
+ def __add__ (self , other : str | AbstractSplit ) -> AbstractSplit :
636
638
"""Sum of 2 splits."""
637
639
if not isinstance (other , (str , AbstractSplit )):
638
640
raise TypeError (f'Adding split { self !r} with non-split value: { other !r} ' )
0 commit comments