Skip to content

Commit c96a3ed

Browse files
fineguy and The TensorFlow Datasets Authors
authored and committed
Add VersionOrStr type annotations.
PiperOrigin-RevId: 693252714
1 parent 84883f8 commit c96a3ed

File tree

5 files changed

+22
-10
lines changed

5 files changed

+22
-10
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,6 @@
6969
ListOrTreeOrElem = type_utils.ListOrTreeOrElem
7070
Tree = type_utils.Tree
7171
TreeDict = type_utils.TreeDict
72-
VersionOrStr = utils.Version | str
7372

7473
FORCE_REDOWNLOAD = download.GenerateMode.FORCE_REDOWNLOAD
7574
REUSE_CACHE_IF_EXISTS = download.GenerateMode.REUSE_CACHE_IF_EXISTS
@@ -107,9 +106,9 @@ class BuilderConfig:
107106
# * Kwargs-only (https://bugs.python.org/issue33129)
108107

109108
name: str
110-
version: VersionOrStr | None = None
109+
version: utils.VersionOrStr | None = None
111110
release_notes: dict[str, str] | None = None
112-
supported_versions: list[VersionOrStr] = dataclasses.field(
111+
supported_versions: list[utils.VersionOrStr] = dataclasses.field(
113112
default_factory=list
114113
)
115114
description: str | None = None

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@
5555
from tensorflow_datasets.core.utils import conversion_utils
5656
from tensorflow_datasets.core.utils import croissant_utils
5757
from tensorflow_datasets.core.utils import type_utils
58-
from tensorflow_datasets.core.utils import version as version_utils
58+
from tensorflow_datasets.core.utils import version as version_lib
5959
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
6060
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
6161
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
@@ -165,7 +165,7 @@ def __init__(
165165
int_dtype: type_utils.TfdsDType | None = np.int64,
166166
float_dtype: type_utils.TfdsDType | None = np.float32,
167167
mapping: Mapping[str, epath.PathLike] | None = None,
168-
overwrite_version: str | None = None,
168+
overwrite_version: version_lib.VersionOrStr | None = None,
169169
filters: Mapping[str, Any] | None = None,
170170
**kwargs: Any,
171171
):
@@ -203,7 +203,7 @@ def __init__(
203203
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
204204
# recommended attribute. If the version is unspecified in Croissant, we set
205205
# it to `1.0.0` in TFDS.
206-
self.VERSION = version_utils.Version( # pylint: disable=invalid-name
206+
self.VERSION = version_lib.Version( # pylint: disable=invalid-name
207207
overwrite_version or self.dataset.metadata.version or '1.0.0'
208208
)
209209
self.RELEASE_NOTES = {} # pylint: disable=invalid-name

tensorflow_datasets/core/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,3 +88,4 @@
8888
from tensorflow_datasets.core.utils.version import Experiment
8989
from tensorflow_datasets.core.utils.version import IsBlocked
9090
from tensorflow_datasets.core.utils.version import Version
91+
from tensorflow_datasets.core.utils.version import VersionOrStr

tensorflow_datasets/core/utils/version.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import dataclasses
2121
import enum
2222
import re
23+
from typing import Union
2324

2425
from etils import epath
2526

@@ -39,6 +40,7 @@ class DatasetVariantBlockedError(ValueError):
3940
# The key is a version or config string, the value is a short sentence
4041
# explaining why that version or config should not be used (or None).
4142
BlockedWithMsg = dict[str, str | None]
43+
VersionOrStr = Union["Version", str]
4244

4345

4446
@dataclasses.dataclass(frozen=True)
@@ -73,7 +75,7 @@ class BlockedVersions:
7375
configs: dict[str, BlockedWithMsg] = dataclasses.field(default_factory=dict)
7476

7577
def is_blocked(
76-
self, version: str | Version, config: str | None = None
78+
self, version: VersionOrStr, config: str | None = None
7779
) -> IsBlocked:
7880
"""Checks whether a version or config is blocked.
7981
@@ -131,7 +133,7 @@ class Version:
131133

132134
def __init__(
133135
self,
134-
version: Version | str,
136+
version: VersionOrStr,
135137
experiments=None,
136138
tfds_version_to_prepare=None,
137139
):
@@ -227,7 +229,7 @@ def match(self, other_version) -> bool:
227229
)
228230

229231
@classmethod
230-
def is_valid(cls, version: Version | str | None) -> bool:
232+
def is_valid(cls, version: VersionOrStr | None) -> bool:
231233
"""Returns True if the version can be parsed."""
232234
if isinstance(version, Version):
233235
return True

tensorflow_datasets/scripts/cli/croissant.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
from tensorflow_datasets.core import file_adapters
3939
from tensorflow_datasets.core.dataset_builders import croissant_builder
4040
from tensorflow_datasets.core.utils import croissant_utils
41+
from tensorflow_datasets.core.utils import version as version_lib
4142
from tensorflow_datasets.scripts.cli import cli_utils
4243

4344

@@ -112,6 +113,15 @@ def record_set_ids(self) -> list[str]:
112113
self.dataset.metadata
113114
)
114115

116+
@functools.cached_property
117+
def version(self) -> version_lib.Version:
118+
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
119+
# recommended attribute. If the version is unspecified in Croissant, we set
120+
# it to `1.0.0` in TFDS.
121+
return version_lib.Version(
122+
self.overwrite_version or self.dataset.metadata.version or '1.0.0'
123+
)
124+
115125

116126
def register_subparser(parsers: argparse._SubParsersAction):
117127
"""Add subparser for `convert_format` command."""
@@ -146,7 +156,7 @@ def prepare_croissant_builder(
146156
file_format=args.file_format,
147157
data_dir=args.data_dir,
148158
mapping=args.mapping_json,
149-
overwrite_version=args.overwrite_version,
159+
overwrite_version=args.version,
150160
)
151161
cli_utils.download_and_prepare(
152162
builder=builder,

0 commit comments

Comments
 (0)