Skip to content

Commit 18571af

Browse files
committed
#141 Add labeled collection
1 parent 364216e commit 18571af

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

tests/test_typing.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
# submodules
1111
from xarray_dataclasses.typing import (
12-
ArrayLike,
1312
Attr,
13+
Collection,
1414
Coord,
1515
Data,
1616
Name,
@@ -34,21 +34,21 @@
3434
(Tuple[()], ()),
3535
(Tuple[X], ("x",)),
3636
(Tuple[X, Y], ("x", "y")),
37-
(ArrayLike[X, Any], ("x",)),
38-
(ArrayLike[Tuple[()], Any], ()),
39-
(ArrayLike[Tuple[X], Any], ("x",)),
40-
(ArrayLike[Tuple[X, Y], Any], ("x", "y")),
37+
(Collection[X, Any], ("x",)),
38+
(Collection[Tuple[()], Any], ()),
39+
(Collection[Tuple[X], Any], ("x",)),
40+
(Collection[Tuple[X, Y], Any], ("x", "y")),
4141
]
4242

4343
testdata_dtype = [
4444
(Any, None),
4545
(NoneType, None),
4646
(Int64, "int64"),
4747
(int, "int"),
48-
(ArrayLike[Any, Any], None),
49-
(ArrayLike[Any, NoneType], None),
50-
(ArrayLike[Any, Int64], "int64"),
51-
(ArrayLike[Any, int], "int"),
48+
(Collection[Any, Any], None),
49+
(Collection[Any, NoneType], None),
50+
(Collection[Any, Int64], "int64"),
51+
(Collection[Any, int], "int"),
5252
]
5353

5454
testdata_field_type = [

xarray_dataclasses/typing.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import (
2424
Any,
2525
ClassVar,
26+
Collection,
2627
Dict,
2728
Hashable,
2829
Optional,
@@ -65,6 +66,18 @@
6566
Sizes = Dict[str, int]
6667

6768

69+
class Labeled(Protocol[TDims]):
70+
"""Type hint for labeled objects."""
71+
72+
pass
73+
74+
75+
class Collection(Labeled[TDims], Collection[TDtype], Protocol):
76+
"""Type hint for labeled collection objects."""
77+
78+
pass
79+
80+
6881
@runtime_checkable
6982
class ArrayLike(Protocol[TDims, TDtype]):
7083
"""Type hint for array-like objects."""
@@ -138,7 +151,7 @@ class Image(AsDataArray):
138151
139152
"""
140153

141-
Coord = Annotated[Union[ArrayLike[TDims, TDtype], TDtype], FieldType.COORD]
154+
Coord = Annotated[Union[Collection[TDims, TDtype], TDtype], FieldType.COORD]
142155
"""Type hint to define coordinate fields (``Coord[TDims, TDtype]``).
143156
144157
Example:
@@ -189,7 +202,7 @@ class Image(AsDataArray):
189202
190203
"""
191204

192-
Data = Annotated[Union[ArrayLike[TDims, TDtype], TDtype], FieldType.DATA]
205+
Data = Annotated[Union[Collection[TDims, TDtype], TDtype], FieldType.DATA]
193206
"""Type hint to define data fields (``Coordof[TDims, TDtype]``).
194207
195208
Examples:
@@ -267,7 +280,7 @@ def get_dims(type_: Any) -> Dims:
267280
args = get_args(type_)
268281
origin = get_origin(type_)
269282

270-
if origin is ArrayLike:
283+
if origin is Collection:
271284
return get_dims(args[0])
272285

273286
if origin is tuple or origin is Tuple:
@@ -298,7 +311,7 @@ def get_dtype(type_: Any) -> Dtype:
298311
args = get_args(type_)
299312
origin = get_origin(type_)
300313

301-
if origin is ArrayLike:
314+
if origin is Collection:
302315
return get_dtype(args[1])
303316

304317
if origin is Literal:

0 commit comments

Comments
 (0)