15
15
16
16
"""ClassLabel feature."""
17
17
18
+ from collections .abc import Iterable
18
19
from typing import Optional , Union
19
20
20
21
from etils import epath
@@ -35,8 +36,8 @@ def __init__(
35
36
self ,
36
37
* ,
37
38
num_classes : int | None = None ,
38
- names = None ,
39
- names_file : str | epath .PathLike | None = None ,
39
+ names : Iterable [ str ] | None = None ,
40
+ names_file : epath .PathLike | None = None ,
40
41
doc : feature_lib .DocArg = None ,
41
42
):
42
43
"""Constructs a ClassLabel FeatureConnector.
@@ -51,11 +52,11 @@ def __init__(
51
52
Note: On python2, the strings are encoded as utf-8.
52
53
53
54
Args:
54
- num_classes: `int`, number of classes. All labels must be < num_classes.
55
- names: `list<str>`, string names for the integer classes. The order in
56
- which the names are provided is kept.
57
- names_file: `str` or `epath.PathLike`, path to a file with names for the
58
- integer classes, one per line.
55
+ num_classes: Number of classes. All labels must be < num_classes.
56
+ names: String names for the integer classes. The order in which the names
57
+ are provided is kept.
58
+ names_file: Path to a file with names for the integer classes, one per
59
+ line.
59
60
doc: Documentation of this feature (e.g. description).
60
61
"""
61
62
super (ClassLabel , self ).__init__ (shape = (), dtype = np .int64 , doc = doc )
@@ -76,7 +77,7 @@ def __init__(
76
77
if num_classes is not None :
77
78
self ._num_classes = num_classes
78
79
elif names is not None :
79
- self .names = names
80
+ self .names = list ( names )
80
81
elif names_file is not None :
81
82
self .names = _load_names_from_file (epath .Path (names_file ))
82
83
@@ -85,10 +86,10 @@ def num_classes(self) -> Optional[int]:
85
86
return self ._num_classes
86
87
87
88
@property
88
- def names (self ) -> Optional [ list [str ] ]:
89
+ def names (self ) -> list [str ]:
89
90
if not self ._int2str :
90
91
return [str (i ) for i in range (self ._num_classes )]
91
- return list ( self ._int2str )
92
+ return self ._int2str
92
93
93
94
@names .setter
94
95
def names (self , new_names : list [str ]):
@@ -224,5 +225,7 @@ def _load_names_from_file(names_filepath: epath.Path) -> list[str]:
224
225
]
225
226
226
227
227
- def _write_names_to_file (names_filepath : epath .Path , names ) -> None :
228
+ def _write_names_to_file (
229
+ names_filepath : epath .Path , names : Iterable [str ]
230
+ ) -> None :
228
231
names_filepath .write_text ("\n " .join (names ) + "\n " )
0 commit comments