Skip to content

Commit 84883f8

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Fix type annotations for ClassLabelFeature.names
PiperOrigin-RevId: 692975782
1 parent 96f8363 commit 84883f8

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

tensorflow_datasets/core/features/class_label_feature.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""ClassLabel feature."""
1717

18+
from collections.abc import Iterable
1819
from typing import Optional, Union
1920

2021
from etils import epath
@@ -35,8 +36,8 @@ def __init__(
3536
self,
3637
*,
3738
num_classes: int | None = None,
38-
names=None,
39-
names_file: str | epath.PathLike | None = None,
39+
names: Iterable[str] | None = None,
40+
names_file: epath.PathLike | None = None,
4041
doc: feature_lib.DocArg = None,
4142
):
4243
"""Constructs a ClassLabel FeatureConnector.
@@ -51,11 +52,11 @@ def __init__(
5152
Note: On python2, the strings are encoded as utf-8.
5253
5354
Args:
54-
num_classes: `int`, number of classes. All labels must be < num_classes.
55-
names: `list<str>`, string names for the integer classes. The order in
56-
which the names are provided is kept.
57-
names_file: `str` or `epath.PathLike`, path to a file with names for the
58-
integer classes, one per line.
55+
num_classes: Number of classes. All labels must be < num_classes.
56+
names: String names for the integer classes. The order in which the names
57+
are provided is kept.
58+
names_file: Path to a file with names for the integer classes, one per
59+
line.
5960
doc: Documentation of this feature (e.g. description).
6061
"""
6162
super(ClassLabel, self).__init__(shape=(), dtype=np.int64, doc=doc)
@@ -76,7 +77,7 @@ def __init__(
7677
if num_classes is not None:
7778
self._num_classes = num_classes
7879
elif names is not None:
79-
self.names = names
80+
self.names = list(names)
8081
elif names_file is not None:
8182
self.names = _load_names_from_file(epath.Path(names_file))
8283

@@ -85,10 +86,10 @@ def num_classes(self) -> Optional[int]:
8586
return self._num_classes
8687

8788
@property
88-
def names(self) -> Optional[list[str]]:
89+
def names(self) -> list[str]:
8990
if not self._int2str:
9091
return [str(i) for i in range(self._num_classes)]
91-
return list(self._int2str)
92+
return self._int2str
9293

9394
@names.setter
9495
def names(self, new_names: list[str]):
@@ -224,5 +225,7 @@ def _load_names_from_file(names_filepath: epath.Path) -> list[str]:
224225
]
225226

226227

227-
def _write_names_to_file(names_filepath: epath.Path, names) -> None:
228+
def _write_names_to_file(
229+
names_filepath: epath.Path, names: Iterable[str]
230+
) -> None:
228231
names_filepath.write_text("\n".join(names) + "\n")

tensorflow_datasets/core/utils/huggingface_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def convert_hf_features(hf_features) -> feature_lib.FeatureConnector:
9999
return feature_lib.Scalar(dtype=_convert_to_np_dtype(hf_features.dtype))
100100
case hf_datasets.ClassLabel():
101101
if hf_features.names:
102-
return feature_lib.ClassLabel(names=hf_features.names)
102+
return feature_lib.ClassLabel(
103+
names=[str(name) for name in hf_features.names]
104+
)
103105
if hf_features.names_file:
104106
return feature_lib.ClassLabel(names_file=hf_features.names_file)
105107
if hf_features.num_classes:

0 commit comments

Comments
 (0)