2
2
3
3
4
4
# standard library
5
- from dataclasses import dataclass , field , is_dataclass
5
+ from dataclasses import Field , dataclass , field , is_dataclass
6
6
from typing import Any , Dict , Hashable , List , Optional , Tuple , Type , Union , cast
7
7
8
8
13
13
14
14
15
15
# submodules
16
- from .typing import ArrayLike , DataClass , DataType , Dims , Dtype
16
+ from .typing import (
17
+ ArrayLike ,
18
+ DataClass ,
19
+ DataType ,
20
+ Dims ,
21
+ Dtype ,
22
+ FieldType ,
23
+ get_dims ,
24
+ get_dtype ,
25
+ get_field_type ,
26
+ get_repr_type ,
27
+ )
17
28
18
29
19
30
# type hints
@@ -63,7 +74,10 @@ class AttrEntry:
63
74
64
75
def __call__ (self ) -> Any :
65
76
"""Create an object according to the entry."""
66
- ...
77
+ if self .value is MISSING :
78
+ raise ValueError ("Value is missing." )
79
+
80
+ return self .value
67
81
68
82
69
83
@dataclass (frozen = True )
@@ -82,7 +96,7 @@ class DataEntry:
82
96
dtype : Dtype = cast (Dtype , None )
83
97
"""Data type of the DataArray that the data is cast to."""
84
98
85
- base : Optional [Type [Any ]] = None
99
+ base : Optional [Type [DataClass [ Any ] ]] = None
86
100
"""Base dataclass that converts the data to a DataArray."""
87
101
88
102
value : Any = MISSING
@@ -91,9 +105,34 @@ class DataEntry:
91
105
cast : bool = True
92
106
"""Whether the value is cast to the data type."""
93
107
108
+ def __post_init__ (self ) -> None :
109
+ """Update the entry if a base dataclass exists."""
110
+ if self .base is None :
111
+ return
112
+
113
+ model = DataModel .from_dataclass (self .base )
114
+
115
+ setattr = object .__setattr__
116
+ setattr (self , "dims" , model .data_vars [0 ].dims )
117
+ setattr (self , "dtype" , model .data_vars [0 ].dtype )
118
+
119
+ if model .names :
120
+ setattr (self , "name" , model .names [0 ].value )
121
+
94
122
def __call__ (self , reference : Optional [DataType ] = None ) -> xr .DataArray :
95
123
"""Create a DataArray object according to the entry."""
96
- ...
124
+ from .dataarray import asdataarray
125
+
126
+ if self .value is MISSING :
127
+ raise ValueError ("Value is missing." )
128
+
129
+ if self .base is None :
130
+ return get_typedarray (self .value , self .dims , self .dtype , reference )
131
+
132
+ if is_dataclass (self .value ):
133
+ return asdataarray (self .value , reference )
134
+ else :
135
+ return asdataarray (self .base (self .value ), reference )
97
136
98
137
99
138
@dataclass (frozen = True )
@@ -106,32 +145,42 @@ class DataModel:
106
145
@property
107
146
def attrs (self ) -> List [AttrEntry ]:
108
147
"""Return a list of attribute entries."""
109
- ...
148
+ return [ v for v in self . entries . values () if v . tag == "attr" ]
110
149
111
150
@property
112
151
def coords (self ) -> List [DataEntry ]:
113
152
"""Return a list of coordinate entries."""
114
- ...
153
+ return [ v for v in self . entries . values () if v . tag == "coord" ]
115
154
116
155
@property
117
156
def data_vars (self ) -> List [DataEntry ]:
118
157
"""Return a list of data variable entries."""
119
- ...
158
+ return [ v for v in self . entries . values () if v . tag == "data" ]
120
159
121
160
@property
122
161
def data_vars_items (self ) -> List [Tuple [str , DataEntry ]]:
123
162
"""Return a list of data variable entries with keys."""
124
- ...
163
+ return [( k , v ) for k , v in self . entries . items () if v . tag == "data" ]
125
164
126
165
@property
127
166
def names (self ) -> List [AttrEntry ]:
128
167
"""Return a list of name entries."""
129
- ...
168
+ return [ v for v in self . entries . values () if v . tag == "name" ]
130
169
131
170
@classmethod
132
171
def from_dataclass (cls , dataclass : AnyDataClass [P ]) -> "DataModel" :
133
172
"""Create a data model from a dataclass or its object."""
134
- ...
173
+ model = cls ()
174
+ eval_dataclass (dataclass )
175
+
176
+ for field in dataclass .__dataclass_fields__ .values ():
177
+ try :
178
+ value = getattr (dataclass , field .name , MISSING )
179
+ model .entries [field .name ] = get_entry (field , value )
180
+ except TypeError :
181
+ pass
182
+
183
+ return model
135
184
136
185
137
186
# runtime functions
@@ -156,7 +205,38 @@ def eval_dataclass(dataclass: AnyDataClass[P]) -> None:
156
205
field .type = types [field .name ]
157
206
158
207
159
- def typedarray (
208
+ def get_entry (field : Field [Any ], value : Any ) -> AnyEntry :
209
+ """Create an entry from a field and its value."""
210
+ field_type = get_field_type (field .type )
211
+ repr_type = get_repr_type (field .type )
212
+
213
+ if field_type is FieldType .ATTR or field_type is FieldType .NAME :
214
+ return AttrEntry (
215
+ name = field .name ,
216
+ tag = field_type .value ,
217
+ value = value ,
218
+ type = repr_type ,
219
+ )
220
+
221
+ # hereafter field type is either COORD or DATA
222
+ if is_dataclass (repr_type ):
223
+ return DataEntry (
224
+ name = field .name ,
225
+ tag = field_type .value ,
226
+ base = repr_type ,
227
+ value = value ,
228
+ )
229
+ else :
230
+ return DataEntry (
231
+ name = field .name ,
232
+ tag = field_type .value ,
233
+ dims = get_dims (repr_type ),
234
+ dtype = get_dtype (repr_type ),
235
+ value = value ,
236
+ )
237
+
238
+
239
+ def get_typedarray (
160
240
data : Any ,
161
241
dims : Dims ,
162
242
dtype : Dtype ,
0 commit comments