Skip to content

Commit 5c5d744

Browse files
authored
#173 Merge pull request from astropenguin/astropenguin/issue172
Add v2 typing module
2 parents e1e259f + 27b528b commit 5c5d744

File tree

3 files changed

+359
-0
lines changed

3 files changed

+359
-0
lines changed

tests/test_v2_typing.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# standard library
2+
from dataclasses import dataclass
3+
from typing import Any, Tuple, Union
4+
5+
6+
# dependencies
7+
import numpy as np
8+
from xarray_dataclasses.v2.typing import (
9+
Attr,
10+
Coord,
11+
Coordof,
12+
Data,
13+
Dataof,
14+
Role,
15+
get_dims,
16+
get_dtype,
17+
get_name,
18+
get_role,
19+
)
20+
from pytest import mark
21+
from typing_extensions import Annotated as Ann, Literal as L
22+
23+
24+
# test data
25+
@dataclass
class DataClass:
    """Minimal dataclass used as the target of ``Coordof`` and ``Dataof``."""

    data: Any
28+
29+
30+
testdata_dims = [
31+
(Coord[Tuple[()], Any], ()),
32+
(Coord[L["x"], Any], ("x",)),
33+
(Coord[Tuple[L["x"]], Any], ("x",)),
34+
(Coord[Tuple[L["x"], L["y"]], Any], ("x", "y")),
35+
(Coordof[DataClass], None),
36+
(Data[Tuple[()], Any], ()),
37+
(Data[L["x"], Any], ("x",)),
38+
(Data[Tuple[L["x"]], Any], ("x",)),
39+
(Data[Tuple[L["x"], L["y"]], Any], ("x", "y")),
40+
(Dataof[DataClass], None),
41+
(Ann[Coord[L["x"], Any], "coord"], ("x",)),
42+
(Ann[Coordof[DataClass], "coord"], None),
43+
(Ann[Data[L["x"], Any], "data"], ("x",)),
44+
(Ann[Dataof[DataClass], "data"], None),
45+
(Union[Ann[Coord[L["x"], Any], "coord"], Ann[Any, "any"]], ("x",)),
46+
(Union[Ann[Coordof[DataClass], "coord"], Ann[Any, "any"]], None),
47+
(Union[Ann[Data[L["x"], Any], "data"], Ann[Any, "any"]], ("x",)),
48+
(Union[Ann[Dataof[DataClass], "data"], Ann[Any, "any"]], None),
49+
]
50+
51+
testdata_dtype = [
52+
(Coord[Any, Any], None),
53+
(Coord[Any, None], None),
54+
(Coord[Any, int], np.dtype("i8")),
55+
(Coord[Any, L["i8"]], np.dtype("i8")),
56+
(Coordof[DataClass], None),
57+
(Data[Any, Any], None),
58+
(Data[Any, None], None),
59+
(Data[Any, int], np.dtype("i8")),
60+
(Data[Any, L["i8"]], np.dtype("i8")),
61+
(Dataof[DataClass], None),
62+
(Ann[Coord[Any, float], "coord"], np.dtype("f8")),
63+
(Ann[Coordof[DataClass], "coord"], None),
64+
(Ann[Data[Any, float], "data"], np.dtype("f8")),
65+
(Ann[Dataof[DataClass], "data"], None),
66+
(Union[Ann[Coord[Any, float], "coord"], Ann[Any, "any"]], np.dtype("f8")),
67+
(Union[Ann[Coordof[DataClass], "coord"], Ann[Any, "any"]], None),
68+
(Union[Ann[Data[Any, float], "data"], Ann[Any, "any"]], np.dtype("f8")),
69+
(Union[Ann[Dataof[DataClass], "data"], Ann[Any, "any"]], None),
70+
]
71+
72+
testdata_name = [
73+
(Attr[Any], None),
74+
(Coord[Any, Any], None),
75+
(Coordof[DataClass], None),
76+
(Data[Any, Any], None),
77+
(Dataof[DataClass], None),
78+
(Any, None),
79+
(Ann[Attr[Any], "attr"], "attr"),
80+
(Ann[Coord[Any, Any], "coord"], "coord"),
81+
(Ann[Coordof[DataClass], "coord"], "coord"),
82+
(Ann[Data[Any, Any], "data"], "data"),
83+
(Ann[Dataof[DataClass], "data"], "data"),
84+
(Ann[Any, "other"], None),
85+
(Ann[Attr[Any], ..., "attr"], None),
86+
(Ann[Coord[Any, Any], ..., "coord"], None),
87+
(Ann[Coordof[DataClass], ..., "coord"], None),
88+
(Ann[Data[Any, Any], ..., "data"], None),
89+
(Ann[Dataof[DataClass], ..., "data"], None),
90+
(Ann[Any, ..., "other"], None),
91+
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], "attr"),
92+
(Union[Ann[Coord[Any, Any], "coord"], Ann[Any, "any"]], "coord"),
93+
(Union[Ann[Coordof[DataClass], "coord"], Ann[Any, "any"]], "coord"),
94+
(Union[Ann[Data[Any, Any], "data"], Ann[Any, "any"]], "data"),
95+
(Union[Ann[Dataof[DataClass], "data"], Ann[Any, "any"]], "data"),
96+
(Union[Ann[Any, "other"], Ann[Any, "any"]], None),
97+
]
98+
99+
testdata_role = [
100+
(Attr[Any], Role.ATTR),
101+
(Coord[Any, Any], Role.COORD),
102+
(Coordof[DataClass], Role.COORD),
103+
(Data[Any, Any], Role.DATA),
104+
(Dataof[DataClass], Role.DATA),
105+
(Any, Role.OTHER),
106+
(Ann[Attr[Any], "attr"], Role.ATTR),
107+
(Ann[Coord[Any, Any], "coord"], Role.COORD),
108+
(Ann[Coordof[DataClass], "coord"], Role.COORD),
109+
(Ann[Data[Any, Any], "data"], Role.DATA),
110+
(Ann[Dataof[DataClass], "data"], Role.DATA),
111+
(Ann[Any, "other"], Role.OTHER),
112+
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], Role.ATTR),
113+
(Union[Ann[Coord[Any, Any], "coord"], Ann[Any, "any"]], Role.COORD),
114+
(Union[Ann[Coordof[DataClass], "coord"], Ann[Any, "any"]], Role.COORD),
115+
(Union[Ann[Data[Any, Any], "data"], Ann[Any, "any"]], Role.DATA),
116+
(Union[Ann[Dataof[DataClass], "data"], Ann[Any, "any"]], Role.DATA),
117+
(Union[Ann[Any, "other"], Ann[Any, "any"]], Role.OTHER),
118+
]
119+
120+
121+
# test functions
122+
@mark.parametrize("tp, dims", testdata_dims)
def test_get_dims(tp: Any, dims: Any) -> None:
    """Check that ``get_dims`` returns the expected dimensions for each case."""
    extracted = get_dims(tp)
    assert extracted == dims
125+
126+
127+
@mark.parametrize("tp, dtype", testdata_dtype)
def test_get_dtype(tp: Any, dtype: Any) -> None:
    """Check that ``get_dtype`` returns the expected data type for each case."""
    extracted = get_dtype(tp)
    assert extracted == dtype
130+
131+
132+
@mark.parametrize("tp, name", testdata_name)
def test_get_name(tp: Any, name: Any) -> None:
    """Check that ``get_name`` returns the expected name for each case."""
    extracted = get_name(tp)
    assert extracted == name
135+
136+
137+
@mark.parametrize("tp, role", testdata_role)
def test_get_role(tp: Any, role: Any) -> None:
    """Check that ``get_role`` returns the expected role member for each case."""
    extracted = get_role(tp)
    assert extracted is role

xarray_dataclasses/v2/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
__all__ = ["typing"]
2+
3+
4+
from . import typing

xarray_dataclasses/v2/typing.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
__all__ = ["Attr", "Coord", "Coordof", "Data", "Dataof", "Other"]
2+
3+
4+
# standard library
5+
from dataclasses import Field
6+
from enum import Enum, auto
7+
from itertools import chain
8+
from typing import (
9+
Any,
10+
Callable,
11+
Collection,
12+
Dict,
13+
Generic,
14+
Hashable,
15+
Iterable,
16+
Optional,
17+
Tuple,
18+
TypeVar,
19+
Union,
20+
)
21+
22+
23+
# dependencies
24+
import numpy as np
25+
import xarray as xr
26+
from typing_extensions import (
27+
Annotated,
28+
Literal,
29+
ParamSpec,
30+
Protocol,
31+
get_args,
32+
get_origin,
33+
get_type_hints,
34+
)
35+
36+
37+
# type hints (private)
38+
# parameter specification and unconstrained placeholders
P = ParamSpec("P")
T = TypeVar("T")
TDims = TypeVar("TDims")
TDType = TypeVar("TDType")

# xarray objects and the type variables bounded by them
Xarray = Union[xr.DataArray, xr.Dataset]
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
TDataset = TypeVar("TDataset", bound=xr.Dataset)
TXarray = TypeVar("TXarray", bound="Xarray")

# dataclass objects (the bound is given as a string, i.e. a forward reference)
TDataClass = TypeVar("TDataClass", bound="DataClass[Any]")
47+
48+
49+
class DataClass(Protocol[P]):
    """Structural type hint for dataclass objects.

    ``P`` stands for the parameters accepted by ``__init__``.
    """

    __dataclass_fields__: Dict[str, "Field[Any]"]

    def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: ...
56+
57+
58+
class XarrayClass(Protocol[P, TXarray]):
    """Structural type hint for dataclass objects that carry a xarray factory.

    ``P`` stands for the parameters accepted by ``__init__`` and ``TXarray``
    for the type returned by ``__xarray_factory__``.
    """

    __dataclass_fields__: Dict[str, "Field[Any]"]
    __xarray_factory__: Callable[..., TXarray]

    def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: ...
66+
67+
68+
class Dims(Generic[TDims]):
    """Member-less generic class whose type argument stores the type of dimensions."""
72+
73+
74+
class Role(Enum):
    """Roles of dataclass fields, used as annotations in type hints."""

    ATTR = auto()
    """Role of attribute fields."""

    COORD = auto()
    """Role of coordinate fields."""

    DATA = auto()
    """Role of data fields."""

    OTHER = auto()
    """Role of other fields."""

    @classmethod
    def annotates(cls, tp: Any) -> bool:
        """Return True if any type argument of ``tp`` is a member of this enum."""
        for arg in get_args(tp):
            if isinstance(arg, cls):
                return True

        return False
93+
94+
95+
# type hints (public)
96+
Attr = Annotated[T, Role.ATTR]
"""Type hint for attribute fields (``Attr[T]``)."""

# In ``Coord`` and ``Data`` the union lists ``Dims[TDims]`` first and
# ``Collection[TDType]`` second, so the two type arguments stay addressable
# by position once the role annotation is stripped.
Coord = Annotated[Union[Dims[TDims], Collection[TDType]], Role.COORD]
"""Type hint for coordinate fields (``Coord[TDims, TDType]``)."""

Coordof = Annotated[TDataClass, Role.COORD]
"""Type hint for coordinate fields (``Coordof[TDataClass]``)."""

Data = Annotated[Union[Dims[TDims], Collection[TDType]], Role.DATA]
"""Type hint for data fields (``Data[TDims, TDType]``)."""

Dataof = Annotated[TDataClass, Role.DATA]
"""Type hint for data fields (``Dataof[TDataClass]``)."""

Other = Annotated[T, Role.OTHER]
"""Type hint for other fields (``Other[T]``)."""
113+
114+
115+
# runtime functions
116+
def deannotate(tp: Any) -> Any:
    """Return a type hint with all annotations removed, at any nesting level.

    The hint is attached to a throwaway class so that ``get_type_hints``,
    called without extras, does the stripping.
    """

    class _Holder:
        __annotations__ = {"tp": tp}

    resolved = get_type_hints(_Holder)
    return resolved["tp"]
123+
124+
125+
def find_annotated(tp: Any) -> Iterable[Any]:
    """Yield every annotated type found in a type hint.

    An annotated hint is yielded itself, followed by the annotated types
    nested in the type it wraps. For any other hint the search continues
    in each of its type arguments, in order.
    """
    if get_origin(tp) is not Annotated:
        yield from chain.from_iterable(map(find_annotated, get_args(tp)))
        return

    yield tp

    # only the wrapped type is searched further, not the metadata
    wrapped = get_args(tp)[0]
    yield from find_annotated(wrapped)
134+
135+
136+
def get_annotated(tp: Any) -> Any:
    """Return the first role-annotated type in a type hint, with annotations removed.

    Raises:
        TypeError: If no annotated type in ``tp`` carries a role.
    """
    not_found = object()
    first = next(filter(Role.annotates, find_annotated(tp)), not_found)

    if first is not_found:
        raise TypeError("Could not find any role-annotated type.")

    return deannotate(first)
142+
143+
144+
def get_annotations(tp: Any) -> Tuple[Any, ...]:
    """Return the annotations attached to the first role-annotated type.

    Raises:
        TypeError: If no annotated type in ``tp`` carries a role.
    """
    not_found = object()
    first = next(filter(Role.annotates, find_annotated(tp)), not_found)

    if first is not_found:
        raise TypeError("Could not find any role-annotated type.")

    # the first argument is the wrapped type; the rest are the annotations
    _, *annotations = get_args(first)
    return tuple(annotations)
150+
151+
152+
def get_dims(tp: Any) -> Optional[Tuple[str, ...]]:
    """Return the dimensions declared in a type hint, or None if there are none.

    None is returned when ``tp`` has no role-annotated type or when that type
    has no type arguments to read the dimensions from. An empty tuple is
    returned when the dimensions hint has no type arguments or is an empty
    tuple type. A ``Literal`` gives a one-element tuple, and a tuple of
    ``Literal`` types gives one element per literal.

    Raises:
        TypeError: If the dimensions hint has any other form.
    """
    try:
        annotated = get_annotated(tp)
        holder = get_args(annotated)[0]
        dims = get_args(holder)[0]
    except (IndexError, TypeError):
        return None

    dim_args = get_args(dims)
    dim_origin = get_origin(dims)

    # ``Tuple[()]`` is reported as ``()`` or ``((),)`` depending on the version
    if dim_args in ((), ((),)):
        return ()

    if dim_origin is Literal:
        return (str(dim_args[0]),)

    is_tuple = dim_origin is tuple or dim_origin is Tuple

    if not is_tuple or any(get_origin(arg) is not Literal for arg in dim_args):
        raise TypeError(f"Could not find any dims in {tp!r}.")

    return tuple(str(get_args(arg)[0]) for arg in dim_args)
175+
176+
177+
def get_dtype(tp: Any) -> Optional[str]:
    """Return the name of the data type declared in a type hint, or None.

    None is returned when ``tp`` has no role-annotated type, when that type
    has no type arguments to read the data type from, or when the declared
    data type is ``Any`` or ``None``. A ``Literal`` is unwrapped to its first
    value before being converted by ``np.dtype``.
    """
    try:
        annotated = get_annotated(tp)
        holder = get_args(annotated)[1]
        declared = get_args(holder)[0]
    except (IndexError, TypeError):
        return None

    # identity checks on purpose: these are sentinels, not values to compare
    if any(declared is unspecified for unspecified in (Any, type(None))):
        return None

    spec = get_args(declared)[0] if get_origin(declared) is Literal else declared
    return np.dtype(spec).name
191+
192+
193+
def get_name(tp: Any, default: Hashable = None) -> Hashable:
    """Extract a name if found or return given default.

    The name is the annotation that immediately follows the role in the first
    role-annotated type of ``tp`` (e.g. ``"x"`` in ``Annotated[Attr[int], "x"]``).

    Args:
        tp: Type hint to inspect.
        default: Value returned when no name is available.

    Returns:
        The name, or ``default`` if ``tp`` has no role-annotated type,
        if nothing follows the role, or if what follows is ``Ellipsis``.

    Raises:
        ValueError: If the annotation in the name position is not hashable.
    """
    try:
        annotations = get_annotations(tp)
    except TypeError:
        return default

    # Nested ``Annotated`` types are flattened with the inner metadata first,
    # so the role is not necessarily at index 0: locate it instead of assuming
    # a fixed position (otherwise the role itself could be returned as a name).
    role_index = next(
        (i for i, ann in enumerate(annotations) if isinstance(ann, Role)),
        None,
    )

    if role_index is None or role_index + 1 >= len(annotations):
        return default

    name = annotations[role_index + 1]

    if name is Ellipsis:
        return default

    try:
        hash(name)
    except TypeError as error:
        raise ValueError("Could not find any valid name.") from error

    return name
209+
210+
211+
def get_role(tp: Any, default: Role = Role.OTHER) -> Role:
    """Extract a role if found or return given default.

    Args:
        tp: Type hint to inspect.
        default: Role returned when ``tp`` has no role-annotated type.

    Returns:
        The first ``Role`` member among the annotations of the first
        role-annotated type in ``tp``, or ``default`` if there is none.
    """
    try:
        annotations = get_annotations(tp)
    except TypeError:
        return default

    # Nested ``Annotated`` types are flattened with the inner metadata first,
    # so the first annotation is not guaranteed to be the role: pick the first
    # annotation that actually is a ``Role`` member.
    return next((ann for ann in annotations if isinstance(ann, Role)), default)

0 commit comments

Comments
 (0)