
Commit d5af645

fineguy and The TensorFlow Datasets Authors authored and committed

Fix types in dataset_info.py

PiperOrigin-RevId: 685400815

1 parent 565eb47 · commit d5af645

2 files changed: +25 −28 lines

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def _info(self) -> dataset_info.DatasetInfo:
         disable_shuffling=self._disable_shuffling,
     )

-  def get_features(self) -> Optional[feature_lib.FeatureConnector]:
+  def get_features(self) -> features_dict.FeaturesDict:
     """Infers the features dict for the required record set."""
     record_set = croissant_utils.get_record_set(
         self.builder_config.name, metadata=self.metadata
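
Note on this hunk: the return type is narrowed from Optional[feature_lib.FeatureConnector] to the concrete features_dict.FeaturesDict, which the method always returns in practice, so callers no longer have to narrow an Optional before use. A minimal sketch of the effect, using hypothetical stand-in classes rather than the real TFDS types:

    from typing import Optional

    class Feature:
      """Stand-in for feature_lib.FeatureConnector (hypothetical)."""

    class FeaturesDict(Feature):
      """Stand-in for features_dict.FeaturesDict (hypothetical)."""

      def keys(self) -> list[str]:
        return []

    def get_features_old() -> Optional[Feature]:
      return FeaturesDict()

    def get_features_new() -> FeaturesDict:
      return FeaturesDict()

    # Old signature: callers must narrow both the Optional and the subtype.
    # New signature: direct attribute access type-checks.
    keys = get_features_new().keys()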

tensorflow_datasets/core/dataset_info.py

Lines changed: 24 additions & 27 deletions
@@ -40,7 +40,7 @@
 import posixpath
 import tempfile
 import time
-from typing import Any, Optional
+from typing import Any

 from absl import logging
 from etils import epath
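
This import change reflects the theme of the commit: Optional[X] is replaced throughout by PEP 604's X | None spelling, usable in annotations since Python 3.10 (or earlier with from __future__ import annotations). The two spellings are equivalent; a small sketch:

    from typing import Optional

    def old_style(name: Optional[str] = None) -> Optional[int]:
      return len(name) if name is not None else None

    def new_style(name: str | None = None) -> int | None:
      return len(name) if name is not None else None

    assert old_style("abc") == new_style("abc") == 3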
@@ -66,8 +66,7 @@
 # pylint: enable=g-import-not-at-top


-# TODO(b/109648354): Remove the "pytype: disable" comment.
-Nest = tuple["Nest", ...] | dict[str, "Nest"] | str  # pytype: disable=not-supported-yet
+Nest = tuple["Nest", ...] | dict[str, "Nest"] | str
 SupervisedKeysType = tuple[Nest, Nest] | tuple[Nest, Nest, Nest]

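
The pytype suppression could be dropped because recent pytype versions handle recursive (self-referential) type aliases like Nest. A self-contained sketch of the same pattern, runnable on Python 3.10+:

    # A recursive alias: a value is a leaf string or a nested container of itself.
    Nest = tuple["Nest", ...] | dict[str, "Nest"] | str

    def leaves(nest: Nest) -> list[str]:
      """Collects all leaf strings from an arbitrarily nested structure."""
      if isinstance(nest, str):
        return [nest]
      values = nest.values() if isinstance(nest, dict) else nest
      return [leaf for child in values for leaf in leaves(child)]

    assert leaves(("a", {"pair": ("b", "c")})) == ["a", "b", "c"]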
@@ -104,7 +103,7 @@ def load_metadata(self, data_dir):
     raise NotImplementedError()


-@dataclasses.dataclass()
+@dataclasses.dataclass
 class DatasetIdentity:
   """Identity of a dataset that completely identifies a dataset."""

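
@dataclasses.dataclass behaves the same whether it is applied bare or called with empty parentheses, so the parentheses are dropped as noise. A quick check:

    import dataclasses

    @dataclasses.dataclass
    class Bare:
      name: str

    @dataclasses.dataclass()
    class Called:
      name: str

    # Both forms generate the same __init__/__repr__/__eq__ machinery.
    assert Bare("x") == Bare("x") and Called("x") == Called("x")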
@@ -167,7 +166,7 @@ def from_proto(
     )


-class DatasetInfo(object):
+class DatasetInfo:
   """Information about a dataset.

   `DatasetInfo` documents datasets, including its name, version, and features.
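
Likewise, explicit inheritance from object is redundant in Python 3, where every class is new-style:

    class Explicit(object): ...
    class Implicit: ...

    # Both have exactly the same base class in Python 3.
    assert Explicit.__bases__ == Implicit.__bases__ == (object,)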
@@ -185,15 +184,15 @@ def __init__(
       *,
       builder: DatasetIdentity | Any,
       description: str | None = None,
-      features: Optional[feature_lib.FeatureConnector] = None,
-      supervised_keys: Optional[SupervisedKeysType] = None,
+      features: feature_lib.FeatureConnector | None = None,
+      supervised_keys: SupervisedKeysType | None = None,
       disable_shuffling: bool = False,
       homepage: str | None = None,
       citation: str | None = None,
       metadata: Metadata | None = None,
       license: str | None = None,  # pylint: disable=redefined-builtin
-      redistribution_info: Optional[dict[str, str]] = None,
-      split_dict: Optional[splits_lib.SplitDict] = None,
+      redistribution_info: dict[str, str] | None = None,
+      split_dict: splits_lib.SplitDict | None = None,
       alternative_file_formats: (
           Sequence[str | file_adapters.FileFormat] | None
       ) = None,
@@ -403,7 +402,7 @@ def disable_shuffling(self) -> bool:
     return self.as_proto.disable_shuffling

   @property
-  def homepage(self):
+  def homepage(self) -> str:
     urls = self.as_proto.location.urls
     tfds_homepage = f"https://www.tensorflow.org/datasets/catalog/{self.name}"
     return urls and urls[0] or tfds_homepage
@@ -413,7 +412,7 @@ def citation(self) -> str:
     return self.as_proto.citation

   @property
-  def data_dir(self):
+  def data_dir(self) -> str:
     return self._identity.data_dir

   @property
@@ -431,15 +430,15 @@ def download_size(self) -> utils.Size:
     )

   @download_size.setter
-  def download_size(self, size):
+  def download_size(self, size: int):
     self.as_proto.download_size = size

   @property
   def features(self):
     return self._features

   @property
-  def alternative_file_formats(self) -> Sequence[file_adapters.FileFormat]:
+  def alternative_file_formats(self) -> list[file_adapters.FileFormat]:
     return self._alternative_file_formats

   @property
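
These property hunks add return annotations to the getters (homepage, data_dir), a parameter annotation to the download_size setter, and tighten alternative_file_formats from the abstract Sequence to the concrete list the attribute actually stores. A minimal sketch of the annotated getter/setter pattern:

    class Info:
      def __init__(self) -> None:
        self._download_size = 0

      @property
      def download_size(self) -> int:
        return self._download_size

      @download_size.setter
      def download_size(self, size: int) -> None:
        self._download_size = size

    info = Info()
    info.download_size = 1024          # setter: size is checked as int
    assert info.download_size == 1024  # getter is known to return int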
@@ -454,7 +453,7 @@ def set_is_blocked(self, is_blocked: str) -> None:
     self._is_blocked = is_blocked

   @property
-  def supervised_keys(self) -> Optional[SupervisedKeysType]:
+  def supervised_keys(self) -> SupervisedKeysType | None:
     if not self.as_proto.HasField("supervised_keys"):
       return None
     supervised_keys = self.as_proto.supervised_keys
@@ -576,8 +575,8 @@ def set_splits(self, split_dict: splits_lib.SplitDict) -> None:
     # into the new split_dict. Also add the filename template if it's not set.
     new_split_infos = []
     incomplete_filename_template = naming.ShardedFileTemplate(
+        data_dir=epath.Path(self.data_dir),
         dataset_name=self.name,
-        data_dir=self.data_dir,
         filetype_suffix=(
             self.as_proto.file_format or file_adapters.DEFAULT_FILE_FORMAT.value
         ),
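
Here data_dir, now annotated as str, is wrapped in epath.Path before being passed to naming.ShardedFileTemplate, presumably because the template expects a path object rather than a raw string. etils.epath provides a pathlib-style API that also covers remote filesystems; for example:

    from etils import epath

    # epath.Path accepts local paths as well as URIs such as gs:// buckets,
    # and exposes the familiar pathlib interface.
    data_dir = epath.Path("/tmp/my_dataset")
    shard = data_dir / "train.tfrecord-00000-of-00010"
    print(shard.name)  # -> "train.tfrecord-00000-of-00010"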
@@ -728,22 +727,20 @@ def read_from_directory(self, dataset_info_dir: epath.PathLike) -> None:

     # Restore the feature metadata (vocabulary, labels names,...)
     if self.features:
-      self.features.load_metadata(dataset_info_dir)  # pytype: disable=missing-parameter  # always-use-property-annotation
+      self.features.load_metadata(dataset_info_dir, feature_name=None)
     # For `ReadOnlyBuilder`, reconstruct the features from the config.
     elif feature_lib.make_config_path(dataset_info_dir).exists():
-      self._features = feature_lib.FeatureConnector.from_config(
+      self._features = top_level_feature.TopLevelFeature.from_config(
           dataset_info_dir
       )
+
+    # If the dataset was loaded from file, self.metadata will be `None`, so
+    # we create a MetadataDict first.
+    if not self._metadata:
+      self._metadata = MetadataDict()
     # Restore the MetaDataDict from metadata.json if there is any
-    if (
-        self.metadata is not None
-        or _metadata_filepath(dataset_info_dir).exists()
-    ):
-      # If the dataset was loaded from file, self.metadata will be `None`, so
-      # we create a MetadataDict first.
-      if self.metadata is None:
-        self._metadata = MetadataDict()
-      self.metadata.load_metadata(dataset_info_dir)  # pytype: disable=attribute-error  # always-use-property-annotation
+    if _metadata_filepath(dataset_info_dir).exists():
+      self._metadata.load_metadata(dataset_info_dir)

     # Update fields which are not defined in the code. This means that
     # the code will overwrite fields which are present in
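
The rewritten block replaces a nested conditional (test, maybe create, then load) with an "ensure a default, then conditionally load" sequence, which is what lets both pytype suppressions go away. A distilled sketch of the resulting pattern, with stand-in names:

    from pathlib import Path

    def restore_metadata(metadata: dict | None, metadata_file: Path) -> dict:
      # Guarantee the container exists first...
      if not metadata:
        metadata = {}
      # ...then load only if there is something on disk.
      if metadata_file.exists():
        metadata.update({"loaded": True})
      return metadata  # always a dict, so no suppression is needed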
@@ -1215,7 +1212,7 @@ def pack_as_supervised_ds(
       and isinstance(ds.element_spec, tuple)
       and len(ds.element_spec) == 2
   ):
-    x_key, y_key = ds_info.supervised_keys  # pytype: disable=bad-unpacking  # always-use-property-annotation
+    x_key, y_key = ds_info.supervised_keys  # pytype: disable=bad-unpacking
     ds = ds.map(lambda x, y: {x_key: x, y_key: y})
     return ds
   else:  # If dataset isn't a supervised tuple (input, label), return as-is
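
The bad-unpacking suppression has to stay: supervised_keys is typed as a union of 2- and 3-tuples, so unpacking into two names is not safe in general; it is only correct here because the branch already checked len(ds.element_spec) == 2. A minimal reproduction of what the checker objects to:

    Keys2 = tuple[str, str]
    Keys3 = tuple[str, str, str]

    def unpack(keys: Keys2 | Keys3) -> str:
      # A type checker flags this: if keys is a 3-tuple, the two-name
      # unpacking below raises ValueError at runtime.
      x_key, y_key = keys  # pytype: disable=bad-unpacking
      return x_key + y_key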
