1
- __all__ = ["DataSpec " , "DataOptions " ]
1
+ __all__ = ["DataOptions " , "DataSpec " ]
2
2
3
3
4
4
# standard library
5
- from dataclasses import dataclass , field
5
+ from dataclasses import dataclass , field , fields
6
+ from functools import lru_cache
6
7
from typing import Any , Dict , Generic , Hashable , Optional , Type , TypeVar
7
8
8
9
9
10
# dependencies
10
- from typing_extensions import Literal , TypeAlias
11
+ from typing_extensions import Literal , TypeAlias , get_type_hints
11
12
12
13
13
14
# submodules
14
- from .typing import AnyDType , AnyXarray , DataClass , Dims
15
+ from .typing import (
16
+ AnyDType ,
17
+ AnyField ,
18
+ AnyXarray ,
19
+ DataClass ,
20
+ Dims ,
21
+ Role ,
22
+ get_annotated ,
23
+ get_dataclass ,
24
+ get_dims ,
25
+ get_dtype ,
26
+ get_name ,
27
+ get_role ,
28
+ )
15
29
16
30
17
31
# type hints
18
32
AnySpec : TypeAlias = "ArraySpec | ScalarSpec"
33
+ TDataClass = TypeVar ("TDataClass" , bound = DataClass [...])
19
34
TReturn = TypeVar ("TReturn" , AnyXarray , None )
20
35
21
36
@@ -33,14 +48,31 @@ class ArraySpec:
33
48
default : Any
34
49
"""Default value of the array."""
35
50
36
- dims : Dims
51
+ dims : Dims = ()
37
52
"""Dimensions of the array."""
38
53
39
- type : Optional [AnyDType ]
54
+ type : Optional [AnyDType ] = None
40
55
"""Data type of the array."""
41
56
42
57
origin : Optional [Type [DataClass [Any ]]] = None
43
- """Dataclass of dims and type origins."""
58
+ """Dataclass as origins of name, dims, and type."""
59
+
60
+ def __post_init__ (self ) -> None :
61
+ """Update name, dims, and type if origin exists."""
62
+ if self .origin is None :
63
+ return
64
+
65
+ dataspec = DataSpec .from_dataclass (self .origin )
66
+ setattr = object .__setattr__
67
+
68
+ for spec in dataspec .specs .of_data .values ():
69
+ setattr (self , "dims" , spec .dims )
70
+ setattr (self , "type" , spec .type )
71
+ break
72
+
73
+ for spec in dataspec .specs .of_name .values ():
74
+ setattr (self , "name" , spec .default )
75
+ break
44
76
45
77
46
78
@dataclass (frozen = True )
@@ -101,3 +133,63 @@ class DataSpec:
101
133
102
134
options : DataOptions [Any ] = DataOptions (type (None ))
103
135
"""Options for xarray data creation."""
136
+
137
+ @classmethod
138
+ def from_dataclass (cls , dataclass : Type [DataClass [...]]) -> "DataSpec" :
139
+ """Create a data specification from a dataclass."""
140
+ specs = SpecDict ()
141
+
142
+ for field in fields (eval_fields (dataclass )):
143
+ spec = get_spec (field )
144
+
145
+ if spec is not None :
146
+ specs [field .name ] = spec
147
+
148
+ try :
149
+ return cls (specs , dataclass .__dataoptions__ ) # type: ignore
150
+ except AttributeError :
151
+ return cls (specs )
152
+
153
+
154
+ # runtime functions
155
+ @lru_cache (maxsize = None )
156
+ def eval_fields (dataclass : Type [TDataClass ]) -> Type [TDataClass ]:
157
+ """Evaluate field types of a dataclass."""
158
+ types = get_type_hints (dataclass , include_extras = True )
159
+
160
+ for field in fields (dataclass ):
161
+ field .type = types [field .name ]
162
+
163
+ return dataclass
164
+
165
+
166
+ @lru_cache (maxsize = None )
167
+ def get_spec (field : AnyField ) -> Optional [AnySpec ]:
168
+ """Convert a dataclass field to a specification."""
169
+ name = get_name (field .type , field .name )
170
+ role = get_role (field .type )
171
+
172
+ if role is Role .DATA or role is Role .COORD :
173
+ try :
174
+ return ArraySpec (
175
+ name = name ,
176
+ role = role .value ,
177
+ default = field .default ,
178
+ origin = get_dataclass (field .type ),
179
+ )
180
+ except TypeError :
181
+ return ArraySpec (
182
+ name = name ,
183
+ role = role .value ,
184
+ default = field .default ,
185
+ dims = get_dims (field .type ),
186
+ type = get_dtype (field .type ),
187
+ )
188
+
189
+ if role is Role .ATTR or role is Role .NAME :
190
+ return ScalarSpec (
191
+ name = name ,
192
+ role = role .value ,
193
+ default = field .default ,
194
+ type = get_annotated (field .type ),
195
+ )
0 commit comments