1
- __all__ = ["Attr" , "Column " , "Data" , "Index " , "Other" ]
1
+ __all__ = ["Attr" , "Coord " , "Coordof" , " Data" , "Dataof " , "Other" ]
2
2
3
3
4
4
# standard library
10
10
Callable ,
11
11
Collection ,
12
12
Dict ,
13
+ Generic ,
13
14
Hashable ,
14
15
Iterable ,
15
16
Optional ,
20
21
21
22
22
23
# dependencies
23
- import pandas as pd
24
- from pandas . api . types import pandas_dtype
24
+ import numpy as np
25
+ import xarray as xr
25
26
from typing_extensions import (
26
27
Annotated ,
27
28
Literal ,
34
35
35
36
36
37
# type hints (private)
37
- Pandas = Union [pd .DataFrame , "pd.Series[Any]" ]
38
38
P = ParamSpec ("P" )
39
39
T = TypeVar ("T" )
40
- TPandas = TypeVar ("TPandas" , bound = Pandas )
41
- TFrame = TypeVar ("TFrame" , bound = pd .DataFrame )
42
- TSeries = TypeVar ("TSeries" , bound = "pd.Series[Any]" )
40
+ TDataClass = TypeVar ("TDataClass" , bound = "DataClass[Any]" )
41
+ TDataArray = TypeVar ("TDataArray" , bound = xr .DataArray )
42
+ TDataset = TypeVar ("TDataset" , bound = xr .Dataset )
43
+ TDims = TypeVar ("TDims" )
44
+ TDType = TypeVar ("TDType" )
45
+ TXarray = TypeVar ("TXarray" , bound = "Xarray" )
46
+ Xarray = Union [xr .DataArray , xr .Dataset ]
43
47
44
48
45
49
class DataClass (Protocol [P ]):
@@ -51,31 +55,34 @@ def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
51
55
...
52
56
53
57
54
- class PandasClass (Protocol [P , TPandas ]):
55
- """Type hint for dataclass objects with a pandas factory."""
58
+ class XarrayClass (Protocol [P , TXarray ]):
59
+ """Type hint for dataclass objects with a xarray factory."""
56
60
57
61
__dataclass_fields__ : Dict [str , "Field[Any]" ]
58
- __pandas_factory__ : Callable [..., TPandas ]
62
+ __xarray_factory__ : Callable [..., TXarray ]
59
63
60
64
def __init__ (self , * args : P .args , ** kwargs : P .kwargs ) -> None :
61
65
...
62
66
63
67
68
+ class Dims (Generic [TDims ]):
69
+ """Empty class for storing type of dimensions."""
70
+
71
+ pass
72
+
73
+
64
74
class Role (Enum ):
65
75
"""Annotations for typing dataclass fields."""
66
76
67
77
ATTR = auto ()
68
78
"""Annotation for attribute fields."""
69
79
70
- COLUMN = auto ()
71
- """Annotation for column fields."""
80
+ COORD = auto ()
81
+ """Annotation for coordinate fields."""
72
82
73
83
DATA = auto ()
74
84
"""Annotation for data fields."""
75
85
76
- INDEX = auto ()
77
- """Annotation for index fields."""
78
-
79
86
OTHER = auto ()
80
87
"""Annotation for other fields."""
81
88
@@ -89,14 +96,17 @@ def annotates(cls, tp: Any) -> bool:
89
96
Attr = Annotated [T , Role .ATTR ]
90
97
"""Type hint for attribute fields (``Attr[T]``)."""
91
98
92
- Column = Annotated [T , Role .COLUMN ]
93
- """Type hint for column fields (``Column[T ]``)."""
99
+ Coord = Annotated [Union [ Dims [ TDims ], Collection [ TDType ]], Role .COORD ]
100
+ """Type hint for coordinate fields (``Coord[TDims, TDType ]``)."""
94
101
95
- Data = Annotated [Collection [ T ] , Role .DATA ]
96
- """Type hint for data fields (``Data[T ]``)."""
102
+ Coordof = Annotated [TDataClass , Role .COORD ]
103
+ """Type hint for coordinate fields (``Dataof[TDataClass ]``)."""
97
104
98
- Index = Annotated [Collection [T ], Role .INDEX ]
99
- """Type hint for index fields (``Index[T]``)."""
105
+ Data = Annotated [Union [Dims [TDims ], Collection [TDType ]], Role .DATA ]
106
+ """Type hint for data fields (``Coord[TDims, TDType]``)."""
107
+
108
+ Dataof = Annotated [TDataClass , Role .DATA ]
109
+ """Type hint for data fields (``Dataof[TDataClass]``)."""
100
110
101
111
Other = Annotated [T , Role .OTHER ]
102
112
"""Type hint for other fields (``Other[T]``)."""
@@ -139,10 +149,35 @@ def get_annotations(tp: Any) -> Tuple[Any, ...]:
139
149
raise TypeError ("Could not find any role-annotated type." )
140
150
141
151
152
+ def get_dims (tp : Any ) -> Optional [Tuple [str , ...]]:
153
+ """Extract dimensions if found or return None."""
154
+ try :
155
+ dims = get_args (get_args (get_annotated (tp ))[0 ])[0 ]
156
+ except (IndexError , TypeError ):
157
+ return None
158
+
159
+ args = get_args (dims )
160
+ origin = get_origin (dims )
161
+
162
+ if args == () or args == ((),):
163
+ return ()
164
+
165
+ if origin is Literal :
166
+ return (str (args [0 ]),)
167
+
168
+ if not (origin is tuple or origin is Tuple ):
169
+ raise TypeError (f"Could not find any dims in { tp !r} ." )
170
+
171
+ if not all (get_origin (arg ) is Literal for arg in args ):
172
+ raise TypeError (f"Could not find any dims in { tp !r} ." )
173
+
174
+ return tuple (str (get_args (arg )[0 ]) for arg in args )
175
+
176
+
142
177
def get_dtype (tp : Any ) -> Optional [str ]:
143
- """Extract a NumPy or pandas data type ."""
178
+ """Extract a data type if found or return None ."""
144
179
try :
145
- dtype = get_args (get_annotated (tp ))[0 ]
180
+ dtype = get_args (get_args ( get_annotated (tp ))[ 1 ] )[0 ]
146
181
except (IndexError , TypeError ):
147
182
return None
148
183
@@ -152,7 +187,7 @@ def get_dtype(tp: Any) -> Optional[str]:
152
187
if get_origin (dtype ) is Literal :
153
188
dtype = get_args (dtype )[0 ]
154
189
155
- return pandas_dtype (dtype ).name
190
+ return np . dtype (dtype ).name
156
191
157
192
158
193
def get_name (tp : Any , default : Hashable = None ) -> Hashable :
@@ -170,12 +205,12 @@ def get_name(tp: Any, default: Hashable = None) -> Hashable:
170
205
except TypeError :
171
206
raise ValueError ("Could not find any valid name." )
172
207
173
- return name # type: ignore
208
+ return name
174
209
175
210
176
211
def get_role (tp : Any , default : Role = Role .OTHER ) -> Role :
177
212
"""Extract a role if found or return given default."""
178
213
try :
179
- return get_annotations (tp )[0 ] # type: ignore
214
+ return get_annotations (tp )[0 ]
180
215
except TypeError :
181
216
return default
0 commit comments