Skip to content

Commit d7c97ea

Browse files
marcenacp and The TensorFlow Datasets Authors
authored and committed
Fix pytype errors.
PiperOrigin-RevId: 618854330
1 parent a4131b3 commit d7c97ea

File tree

5 files changed

+34
-27
lines changed

5 files changed

+34
-27
lines changed

tensorflow_datasets/core/decode/partial_decode.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ def _normalize_feature_dict(
118118
inner_features = {
119119
k: v for k, v in expected_feature.items() if v is not False # pylint: disable=g-bool-id-comparison
120120
}
121+
feature = typing.cast(features_lib.FeaturesDict, feature)
121122
inner_features = { # Extract the feature subset # pylint: disable=g-complex-comprehension
122-
k: _extract_feature_item( # pytype: disable=wrong-arg-types # always-use-return-annotations
123+
k: _extract_feature_item(
123124
feature=feature,
124125
expected_key=k,
125126
expected_value=v,
@@ -153,18 +154,17 @@ def _extract_features(
153154
# Recurse into FeaturesDict, Sequence
154155
# Use `type` rather than `isinstance` to not recurse into inherited classes.
155156
if type(feature) == features_lib.FeaturesDict: # pylint: disable=unidiomatic-typecheck
157+
feature = typing.cast(features_lib.FeaturesDict, feature)
156158
expected_feature = typing.cast(features_lib.FeaturesDict, expected_feature)
157-
return features_lib.FeaturesDict(
158-
{ # Extract the feature subset # pylint: disable=g-complex-comprehension
159-
k: _extract_feature_item( # pytype: disable=wrong-arg-types # always-use-return-annotations
160-
feature=feature,
161-
expected_key=k,
162-
expected_value=v,
163-
fn=_extract_features,
164-
)
165-
for k, v in expected_feature.items()
166-
}
167-
)
159+
return features_lib.FeaturesDict({ # Extract the feature subset # pylint: disable=g-complex-comprehension
160+
k: _extract_feature_item(
161+
feature=feature,
162+
expected_key=k,
163+
expected_value=v,
164+
fn=_extract_features,
165+
)
166+
for k, v in expected_feature.items()
167+
})
168168
elif type(feature) == features_lib.Sequence: # pylint: disable=unidiomatic-typecheck
169169
feature = typing.cast(features_lib.Sequence, feature)
170170
expected_feature = typing.cast(features_lib.Sequence, expected_feature)

tensorflow_datasets/core/load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def load_all_datasets(
433433
`dict` of `dataset_names` mapping to a `dict` of {`split_name`:
434434
tf.data.Dataset} for each desired datasets.
435435
"""
436-
return self.load_datasets( # pytype: disable=wrong-arg-types
436+
return self.load_datasets(
437437
datasets=self.datasets.keys(), split=split, loader_kwargs=loader_kwargs
438438
)
439439

tensorflow_datasets/datasets/kddcup99/kddcup99_dataset_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,5 +245,5 @@ def _generate_examples(self, gz_path):
245245
row['root_shell'] = bool_utils.parse_bool(row['root_shell'])
246246
row['is_hot_login'] = bool_utils.parse_bool(row['is_hot_login'])
247247
row['is_guest_login'] = bool_utils.parse_bool(row['is_guest_login'])
248-
row['label'] = row['label'].rstrip('.') # pytype: disable=attribute-error
248+
row['label'] = str(row['label']).rstrip('.')
249249
yield index, row

tensorflow_datasets/text/gsm8k/gsm8k.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def _info(self) -> tfds.core.DatasetInfo:
6868
def _split_generators(self, dl_manager: tfds.download.DownloadManager):
6969
"""Returns SplitGenerators."""
7070
extracted = dl_manager.download_and_extract(_URLS)
71-
return {k: self._generate_examples(v) for k, v in extracted.items()} # pytype: disable=wrong-arg-types # always-use-return-annotations
71+
return {k: self._generate_examples(v) for k, v in extracted.items()}
7272

73-
def _generate_examples(self, path: str):
73+
def _generate_examples(self, path: epath.PathLike):
7474
"""Yields examples."""
7575
with epath.Path(path).open() as f:
7676
for i, line in enumerate(f):

tensorflow_datasets/video/tao/tao.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
from __future__ import annotations
1919

2020
import collections
21+
from collections.abc import Mapping
2122
import json
2223
import os
23-
from typing import Any, Dict, Optional, Tuple
24+
from typing import Any
2425

2526
from etils import epath
2627
import numpy as np
@@ -52,12 +53,12 @@
5253
}
5354
"""
5455

55-
NestedDict = Dict[str, Any]
56+
NestedDict = Mapping[str, Any]
5657

5758

5859
def _build_annotations_index(
5960
annotations: NestedDict,
60-
) -> Tuple[NestedDict, NestedDict, NestedDict, NestedDict]:
61+
) -> tuple[NestedDict, NestedDict, NestedDict, NestedDict]:
6162
"""Builds several dictionaries to aid in looking up annotations."""
6263
vids = {x['id']: x for x in annotations['videos']}
6364
images = {x['id']: x for x in annotations['images']}
@@ -72,7 +73,7 @@ def _build_annotations_index(
7273
return vids, ann_to_images, track_to_anns, vid_to_tracks
7374

7475

75-
def _merge_categories_map(annotations: NestedDict) -> Dict[str, str]:
76+
def _merge_categories_map(annotations: NestedDict) -> dict[str, str]:
7677
"""Some categories should be renamed into others.
7778
7879
This code segment is based on the TAO provided preprocessing API.
@@ -91,7 +92,9 @@ def _merge_categories_map(annotations: NestedDict) -> Dict[str, str]:
9192
return merge_map
9293

9394

94-
def _maybe_prepare_manual_data(dl_manager: tfds.download.DownloadManager):
95+
def _maybe_prepare_manual_data(
96+
dl_manager: tfds.download.DownloadManager,
97+
) -> tuple[epath.Path | None, epath.Path | None]:
9598
"""Return paths to the manually downloaded data if it is available."""
9699

97100
# The file has a different name each time it is downloaded.
@@ -115,7 +118,7 @@ def _maybe_prepare_manual_data(dl_manager: tfds.download.DownloadManager):
115118
return dl_manager.extract(files)
116119

117120

118-
def _get_category_id_map(annotations_root) -> Dict[str, int]:
121+
def _get_category_id_map(annotations_root) -> dict[str, int]:
119122
"""Gets a map from the TAO category id to a tfds category index.
120123
121124
The tfds category index is the index which a category appears in the
@@ -150,7 +153,7 @@ def _get_category_id_map(annotations_root) -> Dict[str, int]:
150153

151154

152155
def _preprocess_annotations(
153-
annotations_file: str, id_map: Dict[int, int]
156+
annotations_file: str, id_map: dict[str, int]
154157
) -> NestedDict:
155158
"""Preprocesses the data to group together some category labels."""
156159
with epath.Path(annotations_file).open('r') as f:
@@ -226,8 +229,8 @@ class TaoConfig(tfds.core.BuilderConfig):
226229
def __init__(
227230
self,
228231
*,
229-
height: Optional[int] = None,
230-
width: Optional[int] = None,
232+
height: int | None = None,
233+
width: int | None = None,
231234
**kwargs,
232235
):
233236
"""The parameters specifying how the dataset will be processed.
@@ -391,11 +394,15 @@ def _create_metadata(
391394
return metadata
392395

393396
def _generate_examples(
394-
self, data_path, manual_path, annotations_path, id_map
397+
self,
398+
data_path: epath.PathLike,
399+
manual_path: epath.Path | None,
400+
annotations_path: epath.Path,
401+
id_map: dict[str, int],
395402
):
396403
"""Yields examples."""
397404
beam = tfds.core.lazy_imports.apache_beam
398-
annotations = _preprocess_annotations(annotations_path, id_map) # pytype: disable=wrong-arg-types # always-use-return-annotations
405+
annotations = _preprocess_annotations(os.fspath(annotations_path), id_map)
399406
outs = _build_annotations_index(annotations)
400407
vids, ann_to_images, track_to_anns, vid_to_tracks = outs
401408

0 commit comments

Comments (0)