Skip to content

Commit d5fc5cf

Browse files
committed
#156 Update classmethod and runtime functions
1 parent acd0c2e commit d5fc5cf

File tree

1 file changed

+99
-7
lines changed

1 file changed

+99
-7
lines changed

xarray_dataclasses/specs.py

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,36 @@
1-
__all__ = ["DataSpec", "DataOptions"]
1+
__all__ = ["DataOptions", "DataSpec"]
22

33

44
# standard library
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass, field, fields
6+
from functools import lru_cache
67
from typing import Any, Dict, Generic, Hashable, Optional, Type, TypeVar
78

89

910
# dependencies
10-
from typing_extensions import Literal, TypeAlias
11+
from typing_extensions import Literal, TypeAlias, get_type_hints
1112

1213

1314
# submodules
14-
from .typing import AnyDType, AnyXarray, DataClass, Dims
15+
from .typing import (
16+
AnyDType,
17+
AnyField,
18+
AnyXarray,
19+
DataClass,
20+
Dims,
21+
Role,
22+
get_annotated,
23+
get_dataclass,
24+
get_dims,
25+
get_dtype,
26+
get_name,
27+
get_role,
28+
)
1529

1630

1731
# type hints
1832
AnySpec: TypeAlias = "ArraySpec | ScalarSpec"
33+
TDataClass = TypeVar("TDataClass", bound=DataClass[...])
1934
TReturn = TypeVar("TReturn", AnyXarray, None)
2035

2136

@@ -33,14 +48,31 @@ class ArraySpec:
3348
default: Any
3449
"""Default value of the array."""
3550

36-
dims: Dims
51+
dims: Dims = ()
3752
"""Dimensions of the array."""
3853

39-
type: Optional[AnyDType]
54+
type: Optional[AnyDType] = None
4055
"""Data type of the array."""
4156

4257
origin: Optional[Type[DataClass[Any]]] = None
43-
"""Dataclass of dims and type origins."""
58+
"""Dataclass as origins of name, dims, and type."""
59+
60+
def __post_init__(self) -> None:
61+
"""Update name, dims, and type if origin exists."""
62+
if self.origin is None:
63+
return
64+
65+
dataspec = DataSpec.from_dataclass(self.origin)
66+
setattr = object.__setattr__
67+
68+
for spec in dataspec.specs.of_data.values():
69+
setattr(self, "dims", spec.dims)
70+
setattr(self, "type", spec.type)
71+
break
72+
73+
for spec in dataspec.specs.of_name.values():
74+
setattr(self, "name", spec.default)
75+
break
4476

4577

4678
@dataclass(frozen=True)
@@ -101,3 +133,63 @@ class DataSpec:
101133

102134
options: DataOptions[Any] = DataOptions(type(None))
103135
"""Options for xarray data creation."""
136+
137+
@classmethod
138+
def from_dataclass(cls, dataclass: Type[DataClass[...]]) -> "DataSpec":
139+
"""Create a data specification from a dataclass."""
140+
specs = SpecDict()
141+
142+
for field in fields(eval_fields(dataclass)):
143+
spec = get_spec(field)
144+
145+
if spec is not None:
146+
specs[field.name] = spec
147+
148+
try:
149+
return cls(specs, dataclass.__dataoptions__) # type: ignore
150+
except AttributeError:
151+
return cls(specs)
152+
153+
154+
# runtime functions
155+
@lru_cache(maxsize=None)
156+
def eval_fields(dataclass: Type[TDataClass]) -> Type[TDataClass]:
157+
"""Evaluate field types of a dataclass."""
158+
types = get_type_hints(dataclass, include_extras=True)
159+
160+
for field in fields(dataclass):
161+
field.type = types[field.name]
162+
163+
return dataclass
164+
165+
166+
@lru_cache(maxsize=None)
167+
def get_spec(field: AnyField) -> Optional[AnySpec]:
168+
"""Convert a dataclass field to a specification."""
169+
name = get_name(field.type, field.name)
170+
role = get_role(field.type)
171+
172+
if role is Role.DATA or role is Role.COORD:
173+
try:
174+
return ArraySpec(
175+
name=name,
176+
role=role.value,
177+
default=field.default,
178+
origin=get_dataclass(field.type),
179+
)
180+
except TypeError:
181+
return ArraySpec(
182+
name=name,
183+
role=role.value,
184+
default=field.default,
185+
dims=get_dims(field.type),
186+
type=get_dtype(field.type),
187+
)
188+
189+
if role is Role.ATTR or role is Role.NAME:
190+
return ScalarSpec(
191+
name=name,
192+
role=role.value,
193+
default=field.default,
194+
type=get_annotated(field.type),
195+
)

0 commit comments

Comments
 (0)