Skip to content

Commit 3bfb255

Browse files
committed
#172 Copy typing module of pandas-dataclasses (v0.9.0)
1 parent e1e259f commit 3bfb255

File tree

2 files changed

+185
-0
lines changed

2 files changed

+185
-0
lines changed

xarray_dataclasses/v2/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
__all__ = ["typing"]
2+
3+
4+
from . import typing

xarray_dataclasses/v2/typing.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
__all__ = ["Attr", "Column", "Data", "Index", "Other"]
2+
3+
4+
# standard library
5+
from dataclasses import Field
6+
from enum import Enum, auto
7+
from itertools import chain
8+
from typing import (
9+
Any,
10+
Callable,
11+
Collection,
12+
Dict,
13+
Hashable,
14+
Iterable,
15+
Optional,
16+
Tuple,
17+
TypeVar,
18+
Union,
19+
)
20+
21+
22+
# dependencies
23+
import pandas as pd
24+
from pandas.api.types import pandas_dtype
25+
from typing_extensions import (
26+
Annotated,
27+
Literal,
28+
ParamSpec,
29+
Protocol,
30+
get_args,
31+
get_origin,
32+
get_type_hints,
33+
)
34+
35+
36+
# type hints (private)
37+
Pandas = Union[pd.DataFrame, "pd.Series[Any]"]
38+
P = ParamSpec("P")
39+
T = TypeVar("T")
40+
TPandas = TypeVar("TPandas", bound=Pandas)
41+
TFrame = TypeVar("TFrame", bound=pd.DataFrame)
42+
TSeries = TypeVar("TSeries", bound="pd.Series[Any]")
43+
44+
45+
class DataClass(Protocol[P]):
46+
"""Type hint for dataclass objects."""
47+
48+
__dataclass_fields__: Dict[str, "Field[Any]"]
49+
50+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
51+
...
52+
53+
54+
class PandasClass(Protocol[P, TPandas]):
55+
"""Type hint for dataclass objects with a pandas factory."""
56+
57+
__dataclass_fields__: Dict[str, "Field[Any]"]
58+
__pandas_factory__: Callable[..., TPandas]
59+
60+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
61+
...
62+
63+
64+
class Role(Enum):
65+
"""Annotations for typing dataclass fields."""
66+
67+
ATTR = auto()
68+
"""Annotation for attribute fields."""
69+
70+
COLUMN = auto()
71+
"""Annotation for column fields."""
72+
73+
DATA = auto()
74+
"""Annotation for data fields."""
75+
76+
INDEX = auto()
77+
"""Annotation for index fields."""
78+
79+
OTHER = auto()
80+
"""Annotation for other fields."""
81+
82+
@classmethod
83+
def annotates(cls, tp: Any) -> bool:
84+
"""Check if any role annotates a type hint."""
85+
return any(isinstance(arg, cls) for arg in get_args(tp))
86+
87+
88+
# type hints (public)
89+
Attr = Annotated[T, Role.ATTR]
90+
"""Type hint for attribute fields (``Attr[T]``)."""
91+
92+
Column = Annotated[T, Role.COLUMN]
93+
"""Type hint for column fields (``Column[T]``)."""
94+
95+
Data = Annotated[Collection[T], Role.DATA]
96+
"""Type hint for data fields (``Data[T]``)."""
97+
98+
Index = Annotated[Collection[T], Role.INDEX]
99+
"""Type hint for index fields (``Index[T]``)."""
100+
101+
Other = Annotated[T, Role.OTHER]
102+
"""Type hint for other fields (``Other[T]``)."""
103+
104+
105+
# runtime functions
106+
def deannotate(tp: Any) -> Any:
107+
"""Recursively remove annotations in a type hint."""
108+
109+
class Temporary:
110+
__annotations__ = dict(tp=tp)
111+
112+
return get_type_hints(Temporary)["tp"]
113+
114+
115+
def find_annotated(tp: Any) -> Iterable[Any]:
116+
"""Generate all annotated types in a type hint."""
117+
args = get_args(tp)
118+
119+
if get_origin(tp) is Annotated:
120+
yield tp
121+
yield from find_annotated(args[0])
122+
else:
123+
yield from chain(*map(find_annotated, args))
124+
125+
126+
def get_annotated(tp: Any) -> Any:
127+
"""Extract the first role-annotated type."""
128+
for annotated in filter(Role.annotates, find_annotated(tp)):
129+
return deannotate(annotated)
130+
131+
raise TypeError("Could not find any role-annotated type.")
132+
133+
134+
def get_annotations(tp: Any) -> Tuple[Any, ...]:
135+
"""Extract annotations of the first role-annotated type."""
136+
for annotated in filter(Role.annotates, find_annotated(tp)):
137+
return get_args(annotated)[1:]
138+
139+
raise TypeError("Could not find any role-annotated type.")
140+
141+
142+
def get_dtype(tp: Any) -> Optional[str]:
143+
"""Extract a NumPy or pandas data type."""
144+
try:
145+
dtype = get_args(get_annotated(tp))[0]
146+
except (IndexError, TypeError):
147+
return None
148+
149+
if dtype is Any or dtype is type(None):
150+
return None
151+
152+
if get_origin(dtype) is Literal:
153+
dtype = get_args(dtype)[0]
154+
155+
return pandas_dtype(dtype).name
156+
157+
158+
def get_name(tp: Any, default: Hashable = None) -> Hashable:
159+
"""Extract a name if found or return given default."""
160+
try:
161+
name = get_annotations(tp)[1]
162+
except (IndexError, TypeError):
163+
return default
164+
165+
if name is Ellipsis:
166+
return default
167+
168+
try:
169+
hash(name)
170+
except TypeError:
171+
raise ValueError("Could not find any valid name.")
172+
173+
return name # type: ignore
174+
175+
176+
def get_role(tp: Any, default: Role = Role.OTHER) -> Role:
177+
"""Extract a role if found or return given default."""
178+
try:
179+
return get_annotations(tp)[0] # type: ignore
180+
except TypeError:
181+
return default

0 commit comments

Comments
 (0)