 from __future__ import annotations

+from collections.abc import Mapping, MutableMapping
 import dataclasses
 import functools
 import os
 import re
 import textwrap
-from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import Any

 from etils import epath
 from tensorflow_datasets.core.utils import py_utils
@@ ... @@
 _first_cap_re = re.compile('(.)([A-Z][a-z0-9]+)')
 _all_cap_re = re.compile('([a-z0-9])([A-Z])')

-Value = Union[str, int, float, bool]
+Value = str | int | float | bool


 @dataclasses.dataclass(eq=True, order=True, frozen=True)
 class DatasetName:
   """Dataset namespace+name."""

-  namespace: Optional[str]
+  namespace: str | None
   name: str

   def __init__(
       self,
-      namespace_name: Optional[str] = None,
+      namespace_name: str | None = None,
       *,
-      namespace: Optional[str] = None,
-      name: Optional[str] = None,
+      namespace: str | None = None,
+      name: str | None = None,
   ):
     if namespace_name and bool(namespace or name):
       raise ValueError(
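The typing changes that run through the rest of this diff all follow the pattern above: because the file already has `from __future__ import annotations`, annotations are stored as strings rather than evaluated, so PEP 604 unions (`str | None`) and built-in generics (`tuple[...]`, `dict[...]`, `list[...]`) can replace `Optional`, `Union`, `Tuple`, `Dict`, and `List` without requiring Python 3.10 at runtime. A minimal standalone sketch of the before/after style; the `parse` function and the `namespace:name` convention shown here are illustrative, not code from this module:

from __future__ import annotations  # annotations become lazy strings (PEP 563)


# Old style (what the removed typing import supported):
#   def parse(name: str, namespace: Optional[str] = None) -> Tuple[str, Dict[str, str]]:
# New style used by this change:
def parse(name: str, namespace: str | None = None) -> tuple[str, dict[str, str]]:
  """Illustrative only: splits a 'namespace:name' string into its parts."""
  if namespace is None and ':' in name:
    namespace, name = name.split(':', 1)
  return name, {'namespace': namespace or ''}


print(parse('my_namespace:my_dataset'))  # -> ('my_dataset', {'namespace': 'my_namespace'})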
@@ -116,7 +117,7 @@ def is_valid_dataset_and_class_name(name_str: str) -> bool:
 def parse_builder_name_kwargs(
     name: str,
     **builder_kwargs: Any,
-) -> Tuple[DatasetName, Dict[str, Any]]:
+) -> tuple[DatasetName, dict[str, Any]]:
   """Normalize builder kwargs.

   Example:
@@ -145,7 +146,7 @@ def parse_builder_name_kwargs(

 def _dataset_name_and_kwargs_from_name_str(
     name_str: str,
-) -> Tuple[str, Dict[str, Value]]:
+) -> tuple[str, dict[str, Value]]:
   """Extract kwargs from name str."""
   err_msg = textwrap.dedent(f"""\
       Parsing builder name string {name_str} failed.
@@ -251,7 +252,7 @@ def get_split(self, split: str) -> str:

   def dataset_dir(
       self,
-      data_dir: Optional[epath.PathLike] = None,
+      data_dir: epath.PathLike | None = None,
   ) -> epath.Path:
     """Returns the path where the data of this dataset lives.

@@ -286,8 +287,8 @@ def replace(self, **kwargs: Any) -> DatasetReference:
   def from_tfds_name(
       cls,
       tfds_name: str,
-      split_mapping: Optional[Mapping[str, str]] = None,
-      data_dir: Union[None, str, os.PathLike] = None,  # pylint: disable=g-bare-generic
+      split_mapping: Mapping[str, str] | None = None,
+      data_dir: str | os.PathLike | None = None,  # pylint: disable=g-bare-generic
   ) -> DatasetReference:
     """Returns the `DatasetReference` for the given TFDS dataset."""
     parsed_name, builder_kwargs = parse_builder_name_kwargs(tfds_name)
@@ -392,7 +393,7 @@ def _strip_encoding_suffix(path: str) -> str:
   return path[: path.rfind('%')]


-def _num_digits_needed(num_shards: Optional[int]) -> int:
+def _num_digits_needed(num_shards: int | None) -> int:
   return max(len(str(num_shards or 0)), _DEFAULT_NUM_DIGITS_FOR_SHARDS)


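`_num_digits_needed` picks the width of the zero-padded shard numbers in suffixes such as `-00000-of-00042`, never narrower than `_DEFAULT_NUM_DIGITS_FOR_SHARDS`. A small standalone sketch of that padding logic; the constant value 5 and the `shard_suffix` helper are assumptions for illustration, not taken from this module:

from __future__ import annotations

_ASSUMED_DEFAULT_NUM_DIGITS = 5  # assumption for illustration; not read from the diff


def num_digits_needed(num_shards: int | None) -> int:
  # Wide enough for `num_shards`, but never narrower than the default width.
  return max(len(str(num_shards or 0)), _ASSUMED_DEFAULT_NUM_DIGITS)


def shard_suffix(shard_index: int, num_shards: int) -> str:
  width = num_digits_needed(num_shards)
  return f'{shard_index:0{width}d}-of-{num_shards:0{width}d}'


print(shard_suffix(3, 42))  # -> 00003-of-00042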
@@ -447,6 +448,13 @@ def _filename_template_to_regex(filename_template: str) -> str:
   return result


+def _regex_for_template(template: str) -> re.Pattern[str]:
+  """Returns the regular expression for the given template."""
+  # Strip the encoding suffix since it is only used for read/write operations.
+  template = _strip_encoding_suffix(template)
+  return re.compile(_filename_template_to_regex(template))
+
+
 @dataclasses.dataclass()
 class ShardedFileTemplate:
   """Template to produce filenames for sharded datasets.
@@ -463,9 +471,9 @@ class ShardedFileTemplate:

   data_dir: epath.Path
   template: str = DEFAULT_FILENAME_TEMPLATE
-  dataset_name: Optional[str] = None
-  split: Optional[str] = None
-  filetype_suffix: Optional[str] = None
+  dataset_name: str | None = None
+  split: str | None = None
+  filetype_suffix: str | None = None

   def __post_init__(self):
     self.data_dir = epath.Path(self.data_dir)
@@ -484,16 +492,14 @@ def __post_init__(self):
       self.template = DEFAULT_FILENAME_TEMPLATE

   @functools.cached_property
-  def regex(self) -> 're.Pattern[str]':
+  def regex(self) -> re.Pattern[str]:
     """Returns the regular expression for this template.

     Can be used to test whether a filename matches to this template.
     """
-    # Strip the encoding suffix since it is only used for read/write operations.
-    template = _strip_encoding_suffix(self.template)
-    return re.compile(_filename_template_to_regex(template))
+    return _regex_for_template(self.template)

-  def parse_filename_info(self, filename: str) -> Optional[FilenameInfo]:
+  def parse_filename_info(self, filename: str) -> FilenameInfo | None:
     """Parses the filename using this template.

     Note that when the filename doesn't specify the dataset name, split, or
@@ -507,20 +513,26 @@ def parse_filename_info(self, filename: str) -> Optional[FilenameInfo]:
       the FilenameInfo corresponding to the given file if it could be parsed.
       None otherwise.
     """
-    match = self.regex.fullmatch(filename)
-    if not match:
+
+    def filename_info_from_match(match: re.Match[str]) -> FilenameInfo:
+      groupdict = match.groupdict()
+      shard_index = groupdict.get('shard_index')
+      num_shards = groupdict.get('num_shards')
+      return FilenameInfo(
+          dataset_name=groupdict.get('dataset_name', self.dataset_name),
+          split=groupdict.get('split', self.split),
+          filetype_suffix=groupdict.get(
+              'filetype_suffix', self.filetype_suffix
+          ),
+          shard_index=int(shard_index) if shard_index is not None else None,
+          num_shards=int(num_shards) if num_shards is not None else None,
+          filename_template=self,
+      )
+
+    if match := self.regex.fullmatch(filename):
+      return filename_info_from_match(match)
+    else:
       return None
-    groupdict = match.groupdict()
-    shard_index = groupdict.get('shard_index')
-    num_shards = groupdict.get('num_shards')
-    return FilenameInfo(
-        dataset_name=groupdict.get('dataset_name', self.dataset_name),
-        split=groupdict.get('split', self.split),
-        filetype_suffix=groupdict.get('filetype_suffix', self.filetype_suffix),
-        shard_index=int(shard_index) if shard_index is not None else None,
-        num_shards=int(num_shards) if num_shards is not None else None,
-        filename_template=self,
-    )

   def is_valid(self, filename: str) -> bool:
     """Returns whether the given filename follows this template."""
@@ -564,7 +576,7 @@ def relative_filepath(
       self,
       *,
       shard_index: int,
-      num_shards: Optional[int],
+      num_shards: int | None,
   ) -> str:
     """Returns the path (relative to the data dir) of the shard."""
     mappings = self._default_mappings()
@@ -590,7 +602,7 @@ def sharded_filepath(
       self,
       *,
       shard_index: int,
-      num_shards: Optional[int],
+      num_shards: int | None,
   ) -> epath.Path:
     """Returns the filename (including full path if `data_dir` is set) for the given shard."""
     return self.data_dir / self.relative_filepath(
@@ -600,7 +612,7 @@ def sharded_filepath(
   def sharded_filepaths(
       self,
       num_shards: int,
-  ) -> List[epath.Path]:
+  ) -> list[epath.Path]:
     return [
         self.sharded_filepath(shard_index=i, num_shards=num_shards)
         for i in range(num_shards)
@@ -617,7 +629,7 @@ def filepath_prefix(
   def sharded_filepaths_pattern(
       self,
       *,
-      num_shards: Optional[int] = None,
+      num_shards: int | None = None,
   ) -> str:
     """Returns a pattern describing all the file paths captured by this template.

@@ -640,7 +652,7 @@ def sharded_filepaths_pattern(
       replacement = '*'
     return _replace_shard_pattern(os.fspath(a_filepath), replacement)

-  def sharded_filenames(self, num_shards: int) -> List[str]:
+  def sharded_filenames(self, num_shards: int) -> list[str]:
     return [path.name for path in self.sharded_filepaths(num_shards=num_shards)]

   def replace(self, **kwargs: Any) -> 'ShardedFileTemplate':
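For context, `sharded_filepaths_pattern` collapses the per-shard paths into one pattern by swapping the shard suffix of a concrete path for a replacement string, falling back to a `*` wildcard when `num_shards` is unknown (the `replacement = '*'` branch above). A rough sketch of that substitution; the `-XXXXX-of-XXXXX` suffix shape and the regex are assumptions, not the module's `_replace_shard_pattern`:

import os
import re

# Illustrative stand-in for _replace_shard_pattern: swap the trailing
# `-XXXXX-of-XXXXX` suffix (shape assumed) for the given replacement.
_SHARD_SUFFIX_RX = re.compile(r'-\d+-of-\d+$')


def replace_shard_pattern(filepath: str, replacement: str) -> str:
  return _SHARD_SUFFIX_RX.sub(replacement, filepath)


a_filepath = os.path.join('/data', 'mnist-train.tfrecord-00000-of-00042')
print(replace_shard_pattern(a_filepath, '*'))    # /data/mnist-train.tfrecord*
print(replace_shard_pattern(a_filepath, '@42'))  # /data/mnist-train.tfrecord@42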
@@ -653,8 +665,8 @@ def filepattern_for_dataset_split(
     dataset_name: str,
     split: str,
     data_dir: str,
-    filetype_suffix: Optional[str] = None,
-    num_shards: Optional[int] = None,
+    filetype_suffix: str | None = None,
+    num_shards: int | None = None,
 ) -> str:
   """Returns the file pattern for the given dataset.

@@ -682,8 +694,8 @@ def filenames_for_dataset_split(
     split: str,
     num_shards: int,
     filetype_suffix: str,
-    data_dir: Optional[epath.PathLike] = None,
-) -> List[str]:
+    data_dir: epath.PathLike | None = None,
+) -> list[str]:
   """Returns the list of filenames for the given dataset and split."""
   # TODO(tfds): remove this by start using ShardedFileTemplate
   template = ShardedFileTemplate(
@@ -703,7 +715,7 @@ def filepaths_for_dataset_split(
     num_shards: int,
     data_dir: str,
     filetype_suffix: str,
-) -> List[str]:
+) -> list[str]:
   """File paths of a given dataset split."""
   # TODO(tfds): remove this by start using ShardedFileTemplate
   template = ShardedFileTemplate(
@@ -718,7 +730,7 @@ def filepaths_for_dataset_split(


 def _get_filename_template(
-    filename: str, filename_template: Optional[ShardedFileTemplate]
+    filename: str, filename_template: ShardedFileTemplate | None
 ) -> ShardedFileTemplate:
   if filename_template is None:
     return ShardedFileTemplate(data_dir=epath.Path(os.path.dirname(filename)))
@@ -738,12 +750,12 @@ class FilenameInfo:
     filename_template: the template to which this file conforms.
   """

-  dataset_name: Optional[str] = None
-  split: Optional[str] = None
-  filetype_suffix: Optional[str] = None
-  shard_index: Optional[int] = None
-  num_shards: Optional[int] = None
-  filename_template: Optional[ShardedFileTemplate] = None
+  dataset_name: str | None = None
+  split: str | None = None
+  filetype_suffix: str | None = None
+  shard_index: int | None = None
+  num_shards: int | None = None
+  filename_template: ShardedFileTemplate | None = None

   def full_filename_template(self):
     template = self.filename_template or ShardedFileTemplate(
@@ -763,7 +775,7 @@ def replace(self, **kwargs: Any) -> 'FilenameInfo':
   def from_str(
       cls,
       filename: str,
-      filename_template: Optional[ShardedFileTemplate] = None,
+      filename_template: ShardedFileTemplate | None = None,
   ) -> 'FilenameInfo':
     """Factory to create a `FilenameInfo` from filename."""
     filename_template = _get_filename_template(filename, filename_template)
@@ -780,7 +792,7 @@ def from_str(
   @staticmethod
   def is_valid(
       filename: str,
-      filename_template: Optional[ShardedFileTemplate] = None,
+      filename_template: ShardedFileTemplate | None = None,
   ) -> bool:
     """Returns True if the filename follow the given pattern."""
     filename_template = _get_filename_template(filename, filename_template)