Skip to content

Commit 771384c

Browse files
committed
#174 Update copied specs module for xarray
1 parent e687e34 commit 771384c

File tree

1 file changed

+44
-22
lines changed

1 file changed

+44
-22
lines changed

xarray_dataclasses/v2/specs.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,28 @@
22

33

44
# standard library
5-
from dataclasses import dataclass, replace
5+
from dataclasses import dataclass, is_dataclass, replace
66
from dataclasses import Field as Field_, fields as fields_
77
from functools import lru_cache
8-
from typing import Any, Callable, Hashable, List, Optional, Type
8+
from typing import Any, Callable, Hashable, List, Optional, Tuple, Type
99

1010

1111
# dependencies
1212
from typing_extensions import Literal, get_type_hints
1313

1414

1515
# submodules
16-
from .typing import P, DataClass, Pandas, Role, get_dtype, get_name, get_role
16+
from .typing import (
17+
P,
18+
DataClass,
19+
Role,
20+
Xarray,
21+
get_annotated,
22+
get_dims,
23+
get_dtype,
24+
get_name,
25+
get_role,
26+
)
1727

1828

1929
# runtime classes
@@ -27,17 +37,33 @@ class Field:
2737
name: Hashable
2838
"""Name of the field."""
2939

30-
role: Literal["attr", "column", "data", "index"]
40+
role: Literal["attr", "coord", "data"]
3141
"""Role of the field."""
3242

33-
type: Optional[Any]
43+
default: Any
44+
"""Default value of the field data."""
45+
46+
type: Optional[Any] = None
3447
"""Type (hint) of the field data."""
3548

36-
dtype: Optional[str]
49+
dims: Optional[Tuple[str, ...]] = None
50+
"""Dimensions of the field data."""
51+
52+
dtype: Optional[str] = None
3753
"""Data type of the field data."""
3854

39-
default: Any
40-
"""Default value of the field data."""
55+
def __post_init__(self) -> None:
56+
"""Post updates for coordinate and data fields."""
57+
if not (self.role == "coord" or self.role == "data"):
58+
return None
59+
60+
if is_dataclass(self.type):
61+
spec = Spec.from_dataclass(self.type) # type: ignore
62+
field = spec.fields.of_data[0]
63+
object.__setattr__(self, "dims", field.dims)
64+
object.__setattr__(self, "dtype", field.dtype)
65+
else:
66+
object.__setattr__(self, "type", None)
4167

4268
def update(self, obj: DataClass[P]) -> "Field":
4369
"""Update the specification by a dataclass object."""
@@ -57,34 +83,29 @@ def of_attr(self) -> "Fields":
5783
return Fields(field for field in self if field.role == "attr")
5884

5985
@property
60-
def of_column(self) -> "Fields":
61-
"""Select only column field specifications."""
62-
return Fields(field for field in self if field.role == "column")
86+
def of_coord(self) -> "Fields":
87+
"""Select only coordinate field specifications."""
88+
return Fields(field for field in self if field.role == "coord")
6389

6490
@property
6591
def of_data(self) -> "Fields":
6692
"""Select only data field specifications."""
6793
return Fields(field for field in self if field.role == "data")
6894

69-
@property
70-
def of_index(self) -> "Fields":
71-
"""Select only index field specifications."""
72-
return Fields(field for field in self if field.role == "index")
73-
7495
def update(self, obj: DataClass[P]) -> "Fields":
7596
"""Update the specifications by a dataclass object."""
7697
return Fields(field.update(obj) for field in self)
7798

7899

79100
@dataclass(frozen=True)
80101
class Spec:
81-
"""Specification of a pandas dataclass."""
102+
"""Specification of a xarray dataclass."""
82103

83104
fields: Fields
84105
"""List of field specifications."""
85106

86-
factory: Optional[Callable[..., Pandas]] = None
87-
"""Factory for pandas data creation."""
107+
factory: Optional[Callable[..., Xarray]] = None
108+
"""Factory for xarray data creation."""
88109

89110
@classmethod
90111
def from_dataclass(cls, dataclass: Type[DataClass[P]]) -> "Spec":
@@ -97,7 +118,7 @@ def from_dataclass(cls, dataclass: Type[DataClass[P]]) -> "Spec":
97118
if field is not None:
98119
fields.append(field)
99120

100-
factory = getattr(dataclass, "__pandas_factory__", None)
121+
factory = getattr(dataclass, "__xarray_factory__", None)
101122
return cls(fields, factory)
102123

103124
def update(self, obj: DataClass[P]) -> "Spec":
@@ -122,9 +143,10 @@ def convert_field(field_: "Field_[Any]") -> Optional[Field]:
122143
id=field_.name,
123144
name=get_name(field_.type, field_.name),
124145
role=role.name.lower(), # type: ignore
125-
type=field_.type,
126-
dtype=get_dtype(field_.type),
127146
default=field_.default,
147+
type=get_annotated(field_.type),
148+
dims=get_dims(field_.type),
149+
dtype=get_dtype(field_.type),
128150
)
129151

130152

0 commit comments

Comments
 (0)