2
2
3
3
4
4
# standard library
5
- from dataclasses import dataclass , replace
5
+ from dataclasses import dataclass , is_dataclass , replace
6
6
from dataclasses import Field as Field_ , fields as fields_
7
7
from functools import lru_cache
8
- from typing import Any , Callable , Hashable , List , Optional , Type
8
+ from typing import Any , Callable , Hashable , List , Optional , Tuple , Type
9
9
10
10
11
11
# dependencies
12
12
from typing_extensions import Literal , get_type_hints
13
13
14
14
15
15
# submodules
16
- from .typing import P , DataClass , Pandas , Role , get_dtype , get_name , get_role
16
+ from .typing import (
17
+ P ,
18
+ DataClass ,
19
+ Role ,
20
+ Xarray ,
21
+ get_annotated ,
22
+ get_dims ,
23
+ get_dtype ,
24
+ get_name ,
25
+ get_role ,
26
+ )
17
27
18
28
19
29
# runtime classes
@@ -27,17 +37,33 @@ class Field:
27
37
name : Hashable
28
38
"""Name of the field."""
29
39
30
- role : Literal ["attr" , "column " , "data" , "index " ]
40
+ role : Literal ["attr" , "coord " , "data" ]
31
41
"""Role of the field."""
32
42
33
- type : Optional [Any ]
43
+ default : Any
44
+ """Default value of the field data."""
45
+
46
+ type : Optional [Any ] = None
34
47
"""Type (hint) of the field data."""
35
48
36
- dtype : Optional [str ]
49
+ dims : Optional [Tuple [str , ...]] = None
50
+ """Dimensions of the field data."""
51
+
52
+ dtype : Optional [str ] = None
37
53
"""Data type of the field data."""
38
54
39
- default : Any
40
- """Default value of the field data."""
55
+ def __post_init__ (self ) -> None :
56
+ """Post updates for coordinate and data fields."""
57
+ if not (self .role == "coord" or self .role == "data" ):
58
+ return None
59
+
60
+ if is_dataclass (self .type ):
61
+ spec = Spec .from_dataclass (self .type ) # type: ignore
62
+ field = spec .fields .of_data [0 ]
63
+ object .__setattr__ (self , "dims" , field .dims )
64
+ object .__setattr__ (self , "dtype" , field .dtype )
65
+ else :
66
+ object .__setattr__ (self , "type" , None )
41
67
42
68
def update (self , obj : DataClass [P ]) -> "Field" :
43
69
"""Update the specification by a dataclass object."""
@@ -57,34 +83,29 @@ def of_attr(self) -> "Fields":
57
83
return Fields (field for field in self if field .role == "attr" )
58
84
59
85
@property
60
- def of_column (self ) -> "Fields" :
61
- """Select only column field specifications."""
62
- return Fields (field for field in self if field .role == "column " )
86
+ def of_coord (self ) -> "Fields" :
87
+ """Select only coordinate field specifications."""
88
+ return Fields (field for field in self if field .role == "coord " )
63
89
64
90
@property
65
91
def of_data (self ) -> "Fields" :
66
92
"""Select only data field specifications."""
67
93
return Fields (field for field in self if field .role == "data" )
68
94
69
- @property
70
- def of_index (self ) -> "Fields" :
71
- """Select only index field specifications."""
72
- return Fields (field for field in self if field .role == "index" )
73
-
74
95
def update (self , obj : DataClass [P ]) -> "Fields" :
75
96
"""Update the specifications by a dataclass object."""
76
97
return Fields (field .update (obj ) for field in self )
77
98
78
99
79
100
@dataclass (frozen = True )
80
101
class Spec :
81
- """Specification of a pandas dataclass."""
102
+ """Specification of a xarray dataclass."""
82
103
83
104
fields : Fields
84
105
"""List of field specifications."""
85
106
86
- factory : Optional [Callable [..., Pandas ]] = None
87
- """Factory for pandas data creation."""
107
+ factory : Optional [Callable [..., Xarray ]] = None
108
+ """Factory for xarray data creation."""
88
109
89
110
@classmethod
90
111
def from_dataclass (cls , dataclass : Type [DataClass [P ]]) -> "Spec" :
@@ -97,7 +118,7 @@ def from_dataclass(cls, dataclass: Type[DataClass[P]]) -> "Spec":
97
118
if field is not None :
98
119
fields .append (field )
99
120
100
- factory = getattr (dataclass , "__pandas_factory__ " , None )
121
+ factory = getattr (dataclass , "__xarray_factory__ " , None )
101
122
return cls (fields , factory )
102
123
103
124
def update (self , obj : DataClass [P ]) -> "Spec" :
@@ -122,9 +143,10 @@ def convert_field(field_: "Field_[Any]") -> Optional[Field]:
122
143
id = field_ .name ,
123
144
name = get_name (field_ .type , field_ .name ),
124
145
role = role .name .lower (), # type: ignore
125
- type = field_ .type ,
126
- dtype = get_dtype (field_ .type ),
127
146
default = field_ .default ,
147
+ type = get_annotated (field_ .type ),
148
+ dims = get_dims (field_ .type ),
149
+ dtype = get_dtype (field_ .type ),
128
150
)
129
151
130
152
0 commit comments