Skip to content

Commit b6eefff

Browse files
committed
#137 Add runtime codes for DataModel
1 parent 2bac4a5 commit b6eefff

File tree

2 files changed

+116
-41
lines changed

2 files changed

+116
-41
lines changed

xarray_dataclasses/datamodel.py

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
# standard library
5-
from dataclasses import dataclass, field, is_dataclass
5+
from dataclasses import Field, dataclass, field, is_dataclass
66
from typing import Any, Dict, Hashable, List, Optional, Tuple, Type, Union, cast
77

88

@@ -13,7 +13,18 @@
1313

1414

1515
# submodules
16-
from .typing import ArrayLike, DataClass, DataType, Dims, Dtype
16+
from .typing import (
17+
ArrayLike,
18+
DataClass,
19+
DataType,
20+
Dims,
21+
Dtype,
22+
FieldType,
23+
get_dims,
24+
get_dtype,
25+
get_field_type,
26+
get_repr_type,
27+
)
1728

1829

1930
# type hints
@@ -63,7 +74,10 @@ class AttrEntry:
6374

6475
def __call__(self) -> Any:
6576
"""Create an object according to the entry."""
66-
...
77+
if self.value is MISSING:
78+
raise ValueError("Value is missing.")
79+
80+
return self.value
6781

6882

6983
@dataclass(frozen=True)
@@ -82,7 +96,7 @@ class DataEntry:
8296
dtype: Dtype = cast(Dtype, None)
8397
"""Data type of the DataArray that the data is cast to."""
8498

85-
base: Optional[Type[Any]] = None
99+
base: Optional[Type[DataClass[Any]]] = None
86100
"""Base dataclass that converts the data to a DataArray."""
87101

88102
value: Any = MISSING
@@ -91,9 +105,34 @@ class DataEntry:
91105
cast: bool = True
92106
"""Whether the value is cast to the data type."""
93107

108+
def __post_init__(self) -> None:
109+
"""Update the entry if a base dataclass exists."""
110+
if self.base is None:
111+
return
112+
113+
model = DataModel.from_dataclass(self.base)
114+
115+
setattr = object.__setattr__
116+
setattr(self, "dims", model.data_vars[0].dims)
117+
setattr(self, "dtype", model.data_vars[0].dtype)
118+
119+
if model.names:
120+
setattr(self, "name", model.names[0].value)
121+
94122
def __call__(self, reference: Optional[DataType] = None) -> xr.DataArray:
95123
"""Create a DataArray object according to the entry."""
96-
...
124+
from .dataarray import asdataarray
125+
126+
if self.value is MISSING:
127+
raise ValueError("Value is missing.")
128+
129+
if self.base is None:
130+
return get_typedarray(self.value, self.dims, self.dtype, reference)
131+
132+
if is_dataclass(self.value):
133+
return asdataarray(self.value, reference)
134+
else:
135+
return asdataarray(self.base(self.value), reference)
97136

98137

99138
@dataclass(frozen=True)
@@ -106,32 +145,42 @@ class DataModel:
106145
@property
107146
def attrs(self) -> List[AttrEntry]:
108147
"""Return a list of attribute entries."""
109-
...
148+
return [v for v in self.entries.values() if v.tag == "attr"]
110149

111150
@property
112151
def coords(self) -> List[DataEntry]:
113152
"""Return a list of coordinate entries."""
114-
...
153+
return [v for v in self.entries.values() if v.tag == "coord"]
115154

116155
@property
117156
def data_vars(self) -> List[DataEntry]:
118157
"""Return a list of data variable entries."""
119-
...
158+
return [v for v in self.entries.values() if v.tag == "data"]
120159

121160
@property
122161
def data_vars_items(self) -> List[Tuple[str, DataEntry]]:
123162
"""Return a list of data variable entries with keys."""
124-
...
163+
return [(k, v) for k, v in self.entries.items() if v.tag == "data"]
125164

126165
@property
127166
def names(self) -> List[AttrEntry]:
128167
"""Return a list of name entries."""
129-
...
168+
return [v for v in self.entries.values() if v.tag == "name"]
130169

131170
@classmethod
132171
def from_dataclass(cls, dataclass: AnyDataClass[P]) -> "DataModel":
133172
"""Create a data model from a dataclass or its object."""
134-
...
173+
model = cls()
174+
eval_dataclass(dataclass)
175+
176+
for field in dataclass.__dataclass_fields__.values():
177+
try:
178+
value = getattr(dataclass, field.name, MISSING)
179+
model.entries[field.name] = get_entry(field, value)
180+
except TypeError:
181+
pass
182+
183+
return model
135184

136185

137186
# runtime functions
@@ -156,7 +205,38 @@ def eval_dataclass(dataclass: AnyDataClass[P]) -> None:
156205
field.type = types[field.name]
157206

158207

159-
def typedarray(
208+
def get_entry(field: Field[Any], value: Any) -> AnyEntry:
209+
"""Create an entry from a field and its value."""
210+
field_type = get_field_type(field.type)
211+
repr_type = get_repr_type(field.type)
212+
213+
if field_type is FieldType.ATTR or field_type is FieldType.NAME:
214+
return AttrEntry(
215+
name=field.name,
216+
tag=field_type.value,
217+
value=value,
218+
type=repr_type,
219+
)
220+
221+
# hereafter field type is either COORD or DATA
222+
if is_dataclass(repr_type):
223+
return DataEntry(
224+
name=field.name,
225+
tag=field_type.value,
226+
base=repr_type,
227+
value=value,
228+
)
229+
else:
230+
return DataEntry(
231+
name=field.name,
232+
tag=field_type.value,
233+
dims=get_dims(repr_type),
234+
dtype=get_dtype(repr_type),
235+
value=value,
236+
)
237+
238+
239+
def get_typedarray(
160240
data: Any,
161241
dims: Dims,
162242
dtype: Dtype,

xarray_dataclasses/typing.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -48,33 +48,6 @@
4848
)
4949

5050

51-
# constants
52-
class FieldType(Enum):
53-
"""Annotation of xarray-related field hints."""
54-
55-
ATTR = "attr"
56-
"""Annotation of attribute field hints."""
57-
58-
COORD = "coord"
59-
"""Annotation of coordinate field hints."""
60-
61-
COORDOF = "coordof"
62-
"""Annotation of coordinate field hints."""
63-
64-
DATA = "data"
65-
"""Annotation of data (variable) field hints."""
66-
67-
DATAOF = "dataof"
68-
"""Annotation of data (variable) field hints."""
69-
70-
NAME = "name"
71-
"""Annotation of name field hints."""
72-
73-
def annotates(self, hint: Any) -> bool:
74-
"""Check if a field hint is annotated."""
75-
return self in get_args(hint)[1:]
76-
77-
7851
# type hints
7952
P = ParamSpec("P")
8053
T = TypeVar("T")
@@ -120,6 +93,28 @@ def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
12093
__dataclass_fields__: ClassVar[DataClassFields]
12194

12295

96+
# constants
97+
class FieldType(Enum):
98+
"""Annotation of xarray-related field hints."""
99+
100+
ATTR = "attr"
101+
"""Annotation of attribute field hints."""
102+
103+
COORD = "coord"
104+
"""Annotation of coordinate field hints."""
105+
106+
DATA = "data"
107+
"""Annotation of data (variable) field hints."""
108+
109+
NAME = "name"
110+
"""Annotation of name field hints."""
111+
112+
def annotates(self, hint: Any) -> bool:
113+
"""Check if a field hint is annotated."""
114+
return self in get_args(hint)[1:]
115+
116+
117+
# public type hints
123118
Attr = Annotated[T, FieldType.ATTR]
124119
"""Type hint to define attribute fields (``Attr[T]``).
125120
@@ -162,7 +157,7 @@ class Image(AsDataArray):
162157
163158
"""
164159

165-
Coordof = Annotated[Union[TDataClass, Any], FieldType.COORDOF]
160+
Coordof = Annotated[Union[TDataClass, Any], FieldType.COORD]
166161
"""Type hint to define coordinate fields (``Coordof[TDataClass]``).
167162
168163
Unlike ``Coord``, it specifies a dataclass that defines a DataArray class.
@@ -215,7 +210,7 @@ class ColorImage(AsDataset):
215210
216211
"""
217212

218-
Dataof = Annotated[Union[TDataClass, Any], FieldType.DATAOF]
213+
Dataof = Annotated[Union[TDataClass, Any], FieldType.DATA]
219214
"""Type hint to define data fields (``Coordof[TDataClass]``).
220215
221216
Unlike ``Data``, it specifies a dataclass that defines a DataArray class.

0 commit comments

Comments
 (0)