Skip to content

Commit 32790c8

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Use modern Python type annotations
PiperOrigin-RevId: 683575259
1 parent 43c7381 commit 32790c8

File tree

1 file changed

+64
-52
lines changed

1 file changed

+64
-52
lines changed

tensorflow_datasets/core/naming.py

Lines changed: 64 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
from __future__ import annotations
1919

20+
from collections.abc import Mapping, MutableMapping
2021
import dataclasses
2122
import functools
2223
import os
2324
import re
2425
import textwrap
25-
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
26+
from typing import Any
2627

2728
from etils import epath
2829
from tensorflow_datasets.core.utils import py_utils
@@ -63,22 +64,22 @@
6364
_first_cap_re = re.compile('(.)([A-Z][a-z0-9]+)')
6465
_all_cap_re = re.compile('([a-z0-9])([A-Z])')
6566

66-
Value = Union[str, int, float, bool]
67+
Value = str | int | float | bool
6768

6869

6970
@dataclasses.dataclass(eq=True, order=True, frozen=True)
7071
class DatasetName:
7172
"""Dataset namespace+name."""
7273

73-
namespace: Optional[str]
74+
namespace: str | None
7475
name: str
7576

7677
def __init__(
7778
self,
78-
namespace_name: Optional[str] = None,
79+
namespace_name: str | None = None,
7980
*,
80-
namespace: Optional[str] = None,
81-
name: Optional[str] = None,
81+
namespace: str | None = None,
82+
name: str | None = None,
8283
):
8384
if namespace_name and bool(namespace or name):
8485
raise ValueError(
@@ -116,7 +117,7 @@ def is_valid_dataset_and_class_name(name_str: str) -> bool:
116117
def parse_builder_name_kwargs(
117118
name: str,
118119
**builder_kwargs: Any,
119-
) -> Tuple[DatasetName, Dict[str, Any]]:
120+
) -> tuple[DatasetName, dict[str, Any]]:
120121
"""Normalize builder kwargs.
121122
122123
Example:
@@ -145,7 +146,7 @@ def parse_builder_name_kwargs(
145146

146147
def _dataset_name_and_kwargs_from_name_str(
147148
name_str: str,
148-
) -> Tuple[str, Dict[str, Value]]:
149+
) -> tuple[str, dict[str, Value]]:
149150
"""Extract kwargs from name str."""
150151
err_msg = textwrap.dedent(f"""\
151152
Parsing builder name string {name_str} failed.
@@ -251,7 +252,7 @@ def get_split(self, split: str) -> str:
251252

252253
def dataset_dir(
253254
self,
254-
data_dir: Optional[epath.PathLike] = None,
255+
data_dir: epath.PathLike | None = None,
255256
) -> epath.Path:
256257
"""Returns the path where the data of this dataset lives.
257258
@@ -286,8 +287,8 @@ def replace(self, **kwargs: Any) -> DatasetReference:
286287
def from_tfds_name(
287288
cls,
288289
tfds_name: str,
289-
split_mapping: Optional[Mapping[str, str]] = None,
290-
data_dir: Union[None, str, os.PathLike] = None, # pylint: disable=g-bare-generic
290+
split_mapping: Mapping[str, str] | None = None,
291+
data_dir: str | os.PathLike | None = None, # pylint: disable=g-bare-generic
291292
) -> DatasetReference:
292293
"""Returns the `DatasetReference` for the given TFDS dataset."""
293294
parsed_name, builder_kwargs = parse_builder_name_kwargs(tfds_name)
@@ -392,7 +393,7 @@ def _strip_encoding_suffix(path: str) -> str:
392393
return path[: path.rfind('%')]
393394

394395

395-
def _num_digits_needed(num_shards: Optional[int]) -> int:
396+
def _num_digits_needed(num_shards: int | None) -> int:
396397
return max(len(str(num_shards or 0)), _DEFAULT_NUM_DIGITS_FOR_SHARDS)
397398

398399

@@ -447,6 +448,13 @@ def _filename_template_to_regex(filename_template: str) -> str:
447448
return result
448449

449450

451+
def _regex_for_template(template: str) -> re.Pattern[str]:
452+
"""Returns the regular expression for the given template."""
453+
# Strip the encoding suffix since it is only used for read/write operations.
454+
template = _strip_encoding_suffix(template)
455+
return re.compile(_filename_template_to_regex(template))
456+
457+
450458
@dataclasses.dataclass()
451459
class ShardedFileTemplate:
452460
"""Template to produce filenames for sharded datasets.
@@ -463,9 +471,9 @@ class ShardedFileTemplate:
463471

464472
data_dir: epath.Path
465473
template: str = DEFAULT_FILENAME_TEMPLATE
466-
dataset_name: Optional[str] = None
467-
split: Optional[str] = None
468-
filetype_suffix: Optional[str] = None
474+
dataset_name: str | None = None
475+
split: str | None = None
476+
filetype_suffix: str | None = None
469477

470478
def __post_init__(self):
471479
self.data_dir = epath.Path(self.data_dir)
@@ -484,16 +492,14 @@ def __post_init__(self):
484492
self.template = DEFAULT_FILENAME_TEMPLATE
485493

486494
@functools.cached_property
487-
def regex(self) -> 're.Pattern[str]':
495+
def regex(self) -> re.Pattern[str]:
488496
"""Returns the regular expression for this template.
489497
490498
Can be used to test whether a filename matches to this template.
491499
"""
492-
# Strip the encoding suffix since it is only used for read/write operations.
493-
template = _strip_encoding_suffix(self.template)
494-
return re.compile(_filename_template_to_regex(template))
500+
return _regex_for_template(self.template)
495501

496-
def parse_filename_info(self, filename: str) -> Optional[FilenameInfo]:
502+
def parse_filename_info(self, filename: str) -> FilenameInfo | None:
497503
"""Parses the filename using this template.
498504
499505
Note that when the filename doesn't specify the dataset name, split, or
@@ -507,20 +513,26 @@ def parse_filename_info(self, filename: str) -> Optional[FilenameInfo]:
507513
the FilenameInfo corresponding to the given file if it could be parsed.
508514
None otherwise.
509515
"""
510-
match = self.regex.fullmatch(filename)
511-
if not match:
516+
517+
def filename_info_from_match(match: re.Match[str]) -> FilenameInfo:
518+
groupdict = match.groupdict()
519+
shard_index = groupdict.get('shard_index')
520+
num_shards = groupdict.get('num_shards')
521+
return FilenameInfo(
522+
dataset_name=groupdict.get('dataset_name', self.dataset_name),
523+
split=groupdict.get('split', self.split),
524+
filetype_suffix=groupdict.get(
525+
'filetype_suffix', self.filetype_suffix
526+
),
527+
shard_index=int(shard_index) if shard_index is not None else None,
528+
num_shards=int(num_shards) if num_shards is not None else None,
529+
filename_template=self,
530+
)
531+
532+
if match := self.regex.fullmatch(filename):
533+
return filename_info_from_match(match)
534+
else:
512535
return None
513-
groupdict = match.groupdict()
514-
shard_index = groupdict.get('shard_index')
515-
num_shards = groupdict.get('num_shards')
516-
return FilenameInfo(
517-
dataset_name=groupdict.get('dataset_name', self.dataset_name),
518-
split=groupdict.get('split', self.split),
519-
filetype_suffix=groupdict.get('filetype_suffix', self.filetype_suffix),
520-
shard_index=int(shard_index) if shard_index is not None else None,
521-
num_shards=int(num_shards) if num_shards is not None else None,
522-
filename_template=self,
523-
)
524536

525537
def is_valid(self, filename: str) -> bool:
526538
"""Returns whether the given filename follows this template."""
@@ -564,7 +576,7 @@ def relative_filepath(
564576
self,
565577
*,
566578
shard_index: int,
567-
num_shards: Optional[int],
579+
num_shards: int | None,
568580
) -> str:
569581
"""Returns the path (relative to the data dir) of the shard."""
570582
mappings = self._default_mappings()
@@ -590,7 +602,7 @@ def sharded_filepath(
590602
self,
591603
*,
592604
shard_index: int,
593-
num_shards: Optional[int],
605+
num_shards: int | None,
594606
) -> epath.Path:
595607
"""Returns the filename (including full path if `data_dir` is set) for the given shard."""
596608
return self.data_dir / self.relative_filepath(
@@ -600,7 +612,7 @@ def sharded_filepath(
600612
def sharded_filepaths(
601613
self,
602614
num_shards: int,
603-
) -> List[epath.Path]:
615+
) -> list[epath.Path]:
604616
return [
605617
self.sharded_filepath(shard_index=i, num_shards=num_shards)
606618
for i in range(num_shards)
@@ -617,7 +629,7 @@ def filepath_prefix(
617629
def sharded_filepaths_pattern(
618630
self,
619631
*,
620-
num_shards: Optional[int] = None,
632+
num_shards: int | None = None,
621633
) -> str:
622634
"""Returns a pattern describing all the file paths captured by this template.
623635
@@ -640,7 +652,7 @@ def sharded_filepaths_pattern(
640652
replacement = '*'
641653
return _replace_shard_pattern(os.fspath(a_filepath), replacement)
642654

643-
def sharded_filenames(self, num_shards: int) -> List[str]:
655+
def sharded_filenames(self, num_shards: int) -> list[str]:
644656
return [path.name for path in self.sharded_filepaths(num_shards=num_shards)]
645657

646658
def replace(self, **kwargs: Any) -> 'ShardedFileTemplate':
@@ -653,8 +665,8 @@ def filepattern_for_dataset_split(
653665
dataset_name: str,
654666
split: str,
655667
data_dir: str,
656-
filetype_suffix: Optional[str] = None,
657-
num_shards: Optional[int] = None,
668+
filetype_suffix: str | None = None,
669+
num_shards: int | None = None,
658670
) -> str:
659671
"""Returns the file pattern for the given dataset.
660672
@@ -682,8 +694,8 @@ def filenames_for_dataset_split(
682694
split: str,
683695
num_shards: int,
684696
filetype_suffix: str,
685-
data_dir: Optional[epath.PathLike] = None,
686-
) -> List[str]:
697+
data_dir: epath.PathLike | None = None,
698+
) -> list[str]:
687699
"""Returns the list of filenames for the given dataset and split."""
688700
# TODO(tfds): remove this by start using ShardedFileTemplate
689701
template = ShardedFileTemplate(
@@ -703,7 +715,7 @@ def filepaths_for_dataset_split(
703715
num_shards: int,
704716
data_dir: str,
705717
filetype_suffix: str,
706-
) -> List[str]:
718+
) -> list[str]:
707719
"""File paths of a given dataset split."""
708720
# TODO(tfds): remove this by start using ShardedFileTemplate
709721
template = ShardedFileTemplate(
@@ -718,7 +730,7 @@ def filepaths_for_dataset_split(
718730

719731

720732
def _get_filename_template(
721-
filename: str, filename_template: Optional[ShardedFileTemplate]
733+
filename: str, filename_template: ShardedFileTemplate | None
722734
) -> ShardedFileTemplate:
723735
if filename_template is None:
724736
return ShardedFileTemplate(data_dir=epath.Path(os.path.dirname(filename)))
@@ -738,12 +750,12 @@ class FilenameInfo:
738750
filename_template: the template to which this file conforms.
739751
"""
740752

741-
dataset_name: Optional[str] = None
742-
split: Optional[str] = None
743-
filetype_suffix: Optional[str] = None
744-
shard_index: Optional[int] = None
745-
num_shards: Optional[int] = None
746-
filename_template: Optional[ShardedFileTemplate] = None
753+
dataset_name: str | None = None
754+
split: str | None = None
755+
filetype_suffix: str | None = None
756+
shard_index: int | None = None
757+
num_shards: int | None = None
758+
filename_template: ShardedFileTemplate | None = None
747759

748760
def full_filename_template(self):
749761
template = self.filename_template or ShardedFileTemplate(
@@ -763,7 +775,7 @@ def replace(self, **kwargs: Any) -> 'FilenameInfo':
763775
def from_str(
764776
cls,
765777
filename: str,
766-
filename_template: Optional[ShardedFileTemplate] = None,
778+
filename_template: ShardedFileTemplate | None = None,
767779
) -> 'FilenameInfo':
768780
"""Factory to create a `FilenameInfo` from filename."""
769781
filename_template = _get_filename_template(filename, filename_template)
@@ -780,7 +792,7 @@ def from_str(
780792
@staticmethod
781793
def is_valid(
782794
filename: str,
783-
filename_template: Optional[ShardedFileTemplate] = None,
795+
filename_template: ShardedFileTemplate | None = None,
784796
) -> bool:
785797
"""Returns True if the filename follow the given pattern."""
786798
filename_template = _get_filename_template(filename, filename_template)

0 commit comments

Comments
 (0)