Skip to content

Commit e687e34

Browse files
committed
#174 Copy specs module of pandas-dataclasses (v0.9.0)
1 parent 5c5d744 commit e687e34

File tree

2 files changed

+152
-1
lines changed

2 files changed

+152
-1
lines changed

xarray_dataclasses/v2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
__all__ = ["typing"]
1+
__all__ = ["specs", "typing"]
22

33

4+
from . import specs
45
from . import typing

xarray_dataclasses/v2/specs.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
__all__ = ["Spec"]
2+
3+
4+
# standard library
5+
from dataclasses import dataclass, replace
6+
from dataclasses import Field as Field_, fields as fields_
7+
from functools import lru_cache
8+
from typing import Any, Callable, Hashable, List, Optional, Type
9+
10+
11+
# dependencies
12+
from typing_extensions import Literal, get_type_hints
13+
14+
15+
# submodules
16+
from .typing import P, DataClass, Pandas, Role, get_dtype, get_name, get_role
17+
18+
19+
# runtime classes
20+
@dataclass(frozen=True)
class Field:
    """Immutable specification of a single dataclass field."""

    id: str
    """Attribute name used to look the field up on a dataclass object."""

    name: Hashable
    """Name given to the field (a string name may be a format template)."""

    role: Literal["attr", "column", "data", "index"]
    """Role that the field plays."""

    type: Optional[Any]
    """Type hint attached to the field data."""

    dtype: Optional[str]
    """Data type of the field data."""

    default: Any
    """Default value of the field data."""

    def update(self, obj: DataClass[P]) -> "Field":
        """Return a copy of the specification updated by a dataclass object.

        The name is formatted against ``obj`` and the default is replaced
        by the value stored on ``obj`` under this field's identifier
        (the current default is kept when ``obj`` has no such attribute).
        """
        formatted_name = format_name(self.name, obj)
        current_value = getattr(obj, self.id, self.default)
        return replace(self, name=formatted_name, default=current_value)
49+
50+
51+
class Fields(List[Field]):
    """List of field specifications that can be filtered by role."""

    def _with_role(self, role: str) -> "Fields":
        """Return a new list holding only the specifications of a role."""
        return Fields(spec for spec in self if spec.role == role)

    @property
    def of_attr(self) -> "Fields":
        """Field specifications whose role is ``"attr"``."""
        return self._with_role("attr")

    @property
    def of_column(self) -> "Fields":
        """Field specifications whose role is ``"column"``."""
        return self._with_role("column")

    @property
    def of_data(self) -> "Fields":
        """Field specifications whose role is ``"data"``."""
        return self._with_role("data")

    @property
    def of_index(self) -> "Fields":
        """Field specifications whose role is ``"index"``."""
        return self._with_role("index")

    def update(self, obj: DataClass[P]) -> "Fields":
        """Return a new list with every specification updated by an object."""
        updated = [spec.update(obj) for spec in self]
        return Fields(updated)
77+
78+
79+
@dataclass(frozen=True)
class Spec:
    """Immutable specification built from a dataclass."""

    fields: Fields
    """Field specifications of the dataclass."""

    factory: Optional[Callable[..., Pandas]] = None
    """Optional callable used to create the data object."""

    @classmethod
    def from_dataclass(cls, dataclass: Type[DataClass[P]]) -> "Spec":
        """Build a specification from a dataclass.

        Field types are evaluated first, then every dataclass field is
        converted; fields for which the conversion yields ``None`` are
        left out. The factory is read from the ``__pandas_factory__``
        attribute of the dataclass (``None`` when it is absent).
        """
        converted = (convert_field(raw) for raw in fields_(eval_types(dataclass)))
        specs = Fields(spec for spec in converted if spec is not None)
        return cls(specs, getattr(dataclass, "__pandas_factory__", None))

    def update(self, obj: DataClass[P]) -> "Spec":
        """Return a copy whose field specifications are updated by an object."""
        updated_fields = self.fields.update(obj)
        return replace(self, fields=updated_fields)

    def __matmul__(self, obj: DataClass[P]) -> "Spec":
        """Operator form of ``update`` (``spec @ obj``)."""
        return self.update(obj)
110+
111+
112+
# runtime functions
113+
@lru_cache(maxsize=None)
def convert_field(field_: "Field_[Any]") -> Optional[Field]:
    """Turn a dataclass field into a field specification.

    Returns ``None`` when the role derived from the field type is
    ``Role.OTHER``. Results are memoized per dataclass field object.
    """
    annotation = field_.type
    role = get_role(annotation)

    if role is Role.OTHER:
        return None

    return Field(
        id=field_.name,
        name=get_name(annotation, field_.name),
        role=role.name.lower(),  # type: ignore
        type=annotation,
        dtype=get_dtype(annotation),
        default=field_.default,
    )
129+
130+
131+
@lru_cache(maxsize=None)
def eval_types(dataclass: Type[DataClass[P]]) -> Type[DataClass[P]]:
    """Replace the field types of a dataclass by their evaluated hints.

    The hints are resolved with ``get_type_hints(..., include_extras=True)``
    and written back onto the dataclass field objects in place; the same
    dataclass is returned. Results are memoized per dataclass.
    """
    hints = get_type_hints(dataclass, include_extras=True)

    for dc_field in fields_(dataclass):
        dc_field.type = hints[dc_field.name]

    return dataclass
140+
141+
142+
def format_name(name: Hashable, obj: DataClass[P]) -> Hashable:
    """Format a name against a dataclass object.

    A string is formatted with ``obj`` as its only positional argument,
    a tuple is rebuilt (with the same tuple type) from its recursively
    formatted elements, and any other name is returned unchanged.
    """
    if isinstance(name, str):
        return name.format(obj)

    if isinstance(name, tuple):
        return type(name)(format_name(part, obj) for part in name)

    return name

0 commit comments

Comments
 (0)