Skip to content

Commit ec55480

Browse files
committed
#172 Update copied typing module for xarray
1 parent 3bfb255 commit ec55480

File tree

1 file changed

+61
-26
lines changed

1 file changed

+61
-26
lines changed

xarray_dataclasses/v2/typing.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["Attr", "Column", "Data", "Index", "Other"]
1+
__all__ = ["Attr", "Coord", "Coordof", "Data", "Dataof", "Other"]
22

33

44
# standard library
@@ -10,6 +10,7 @@
1010
Callable,
1111
Collection,
1212
Dict,
13+
Generic,
1314
Hashable,
1415
Iterable,
1516
Optional,
@@ -20,8 +21,8 @@
2021

2122

2223
# dependencies
23-
import pandas as pd
24-
from pandas.api.types import pandas_dtype
24+
import numpy as np
25+
import xarray as xr
2526
from typing_extensions import (
2627
Annotated,
2728
Literal,
@@ -34,12 +35,15 @@
3435

3536

3637
# type hints (private)
37-
Pandas = Union[pd.DataFrame, "pd.Series[Any]"]
3838
P = ParamSpec("P")
3939
T = TypeVar("T")
40-
TPandas = TypeVar("TPandas", bound=Pandas)
41-
TFrame = TypeVar("TFrame", bound=pd.DataFrame)
42-
TSeries = TypeVar("TSeries", bound="pd.Series[Any]")
40+
TDataClass = TypeVar("TDataClass", bound="DataClass[Any]")
41+
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
42+
TDataset = TypeVar("TDataset", bound=xr.Dataset)
43+
TDims = TypeVar("TDims")
44+
TDType = TypeVar("TDType")
45+
TXarray = TypeVar("TXarray", bound="Xarray")
46+
Xarray = Union[xr.DataArray, xr.Dataset]
4347

4448

4549
class DataClass(Protocol[P]):
@@ -51,31 +55,34 @@ def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
5155
...
5256

5357

54-
class PandasClass(Protocol[P, TPandas]):
55-
"""Type hint for dataclass objects with a pandas factory."""
58+
class XarrayClass(Protocol[P, TXarray]):
59+
"""Type hint for dataclass objects with a xarray factory."""
5660

5761
__dataclass_fields__: Dict[str, "Field[Any]"]
58-
__pandas_factory__: Callable[..., TPandas]
62+
__xarray_factory__: Callable[..., TXarray]
5963

6064
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
6165
...
6266

6367

68+
class Dims(Generic[TDims]):
69+
"""Empty class for storing type of dimensions."""
70+
71+
pass
72+
73+
6474
class Role(Enum):
6575
"""Annotations for typing dataclass fields."""
6676

6777
ATTR = auto()
6878
"""Annotation for attribute fields."""
6979

70-
COLUMN = auto()
71-
"""Annotation for column fields."""
80+
COORD = auto()
81+
"""Annotation for coordinate fields."""
7282

7383
DATA = auto()
7484
"""Annotation for data fields."""
7585

76-
INDEX = auto()
77-
"""Annotation for index fields."""
78-
7986
OTHER = auto()
8087
"""Annotation for other fields."""
8188

@@ -89,14 +96,17 @@ def annotates(cls, tp: Any) -> bool:
8996
Attr = Annotated[T, Role.ATTR]
9097
"""Type hint for attribute fields (``Attr[T]``)."""
9198

92-
Column = Annotated[T, Role.COLUMN]
93-
"""Type hint for column fields (``Column[T]``)."""
99+
Coord = Annotated[Union[Dims[TDims], Collection[TDType]], Role.COORD]
100+
"""Type hint for coordinate fields (``Coord[TDims, TDType]``)."""
94101

95-
Data = Annotated[Collection[T], Role.DATA]
96-
"""Type hint for data fields (``Data[T]``)."""
102+
Coordof = Annotated[TDataClass, Role.COORD]
103+
"""Type hint for coordinate fields (``Dataof[TDataClass]``)."""
97104

98-
Index = Annotated[Collection[T], Role.INDEX]
99-
"""Type hint for index fields (``Index[T]``)."""
105+
Data = Annotated[Union[Dims[TDims], Collection[TDType]], Role.DATA]
106+
"""Type hint for data fields (``Coord[TDims, TDType]``)."""
107+
108+
Dataof = Annotated[TDataClass, Role.DATA]
109+
"""Type hint for data fields (``Dataof[TDataClass]``)."""
100110

101111
Other = Annotated[T, Role.OTHER]
102112
"""Type hint for other fields (``Other[T]``)."""
@@ -139,10 +149,35 @@ def get_annotations(tp: Any) -> Tuple[Any, ...]:
139149
raise TypeError("Could not find any role-annotated type.")
140150

141151

152+
def get_dims(tp: Any) -> Optional[Tuple[str, ...]]:
153+
"""Extract dimensions if found or return None."""
154+
try:
155+
dims = get_args(get_args(get_annotated(tp))[0])[0]
156+
except (IndexError, TypeError):
157+
return None
158+
159+
args = get_args(dims)
160+
origin = get_origin(dims)
161+
162+
if args == () or args == ((),):
163+
return ()
164+
165+
if origin is Literal:
166+
return (str(args[0]),)
167+
168+
if not (origin is tuple or origin is Tuple):
169+
raise TypeError(f"Could not find any dims in {tp!r}.")
170+
171+
if not all(get_origin(arg) is Literal for arg in args):
172+
raise TypeError(f"Could not find any dims in {tp!r}.")
173+
174+
return tuple(str(get_args(arg)[0]) for arg in args)
175+
176+
142177
def get_dtype(tp: Any) -> Optional[str]:
143-
"""Extract a NumPy or pandas data type."""
178+
"""Extract a data type if found or return None."""
144179
try:
145-
dtype = get_args(get_annotated(tp))[0]
180+
dtype = get_args(get_args(get_annotated(tp))[1])[0]
146181
except (IndexError, TypeError):
147182
return None
148183

@@ -152,7 +187,7 @@ def get_dtype(tp: Any) -> Optional[str]:
152187
if get_origin(dtype) is Literal:
153188
dtype = get_args(dtype)[0]
154189

155-
return pandas_dtype(dtype).name
190+
return np.dtype(dtype).name
156191

157192

158193
def get_name(tp: Any, default: Hashable = None) -> Hashable:
@@ -170,12 +205,12 @@ def get_name(tp: Any, default: Hashable = None) -> Hashable:
170205
except TypeError:
171206
raise ValueError("Could not find any valid name.")
172207

173-
return name # type: ignore
208+
return name
174209

175210

176211
def get_role(tp: Any, default: Role = Role.OTHER) -> Role:
177212
"""Extract a role if found or return given default."""
178213
try:
179-
return get_annotations(tp)[0] # type: ignore
214+
return get_annotations(tp)[0]
180215
except TypeError:
181216
return default

0 commit comments

Comments
 (0)