Skip to content

Commit 32392f6

Browse files
authored
#163 Merge pull request from astropenguin/astropenguin/issue156
Add specs module
2 parents 214ab6b + bc37fa8 commit 32392f6

File tree

6 files changed

+440
-57
lines changed

6 files changed

+440
-57
lines changed

tests/test_specs.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# standard library
2+
from dataclasses import MISSING, dataclass
3+
from typing import Tuple
4+
5+
6+
# dependencies
7+
import numpy as np
8+
import xarray as xr
9+
from typing_extensions import Annotated as Ann
10+
from typing_extensions import Literal as L
11+
from xarray_dataclasses.specs import DataOptions, DataSpec
12+
from xarray_dataclasses.typing import Attr, Coordof, Data, Name
13+
14+
15+
# type hints
16+
DataDims = Tuple[L["lon"], L["lat"], L["time"]]
17+
18+
19+
# test datasets
20+
@dataclass
21+
class Lon:
22+
"""Specification of relative longitude."""
23+
24+
data: Data[L["lon"], float]
25+
units: Attr[str] = "deg"
26+
name: Name[str] = "Relative longitude"
27+
28+
29+
@dataclass
30+
class Lat:
31+
"""Specification of relative latitude."""
32+
33+
data: Data[L["lat"], float]
34+
units: Attr[str] = "m"
35+
name: Name[str] = "Relative latitude"
36+
37+
38+
@dataclass
39+
class Time:
40+
"""Specification of time."""
41+
42+
data: Data[L["time"], L["datetime64[ns]"]]
43+
name: Name[str] = "Time in UTC"
44+
45+
46+
@dataclass
47+
class Weather:
48+
"""Time-series spatial weather information at a location."""
49+
50+
temperature: Ann[Data[DataDims, float], "Temperature"]
51+
humidity: Ann[Data[DataDims, float], "Humidity"]
52+
wind_speed: Ann[Data[DataDims, float], "Wind speed"]
53+
wind_direction: Ann[Data[DataDims, float], "Wind direction"]
54+
lon: Coordof[Lon]
55+
lat: Coordof[Lat]
56+
time: Coordof[Time]
57+
location: Attr[str] = "Tokyo"
58+
longitude: Attr[float] = 139.69167
59+
latitude: Attr[float] = 35.68944
60+
name: Name[str] = "weather"
61+
62+
63+
# test functions
64+
def test_temperature() -> None:
65+
spec = DataSpec.from_dataclass(Weather).specs.of_data["temperature"]
66+
67+
assert spec.name == "Temperature"
68+
assert spec.role == "data"
69+
assert spec.dims == ("lon", "lat", "time")
70+
assert spec.dtype == np.dtype("f8")
71+
assert spec.default is MISSING
72+
assert spec.origin is None
73+
74+
75+
def test_humidity() -> None:
76+
spec = DataSpec.from_dataclass(Weather).specs.of_data["humidity"]
77+
78+
assert spec.name == "Humidity"
79+
assert spec.role == "data"
80+
assert spec.dims == ("lon", "lat", "time")
81+
assert spec.dtype == np.dtype("f8")
82+
assert spec.default is MISSING
83+
assert spec.origin is None
84+
85+
86+
def test_wind_speed() -> None:
87+
spec = DataSpec.from_dataclass(Weather).specs.of_data["wind_speed"]
88+
89+
assert spec.name == "Wind speed"
90+
assert spec.role == "data"
91+
assert spec.dims == ("lon", "lat", "time")
92+
assert spec.dtype == np.dtype("f8")
93+
assert spec.default is MISSING
94+
assert spec.origin is None
95+
96+
97+
def test_wind_direction() -> None:
98+
spec = DataSpec.from_dataclass(Weather).specs.of_data["wind_direction"]
99+
100+
assert spec.name == "Wind direction"
101+
assert spec.role == "data"
102+
assert spec.dims == ("lon", "lat", "time")
103+
assert spec.dtype == np.dtype("f8")
104+
assert spec.default is MISSING
105+
assert spec.origin is None
106+
107+
108+
def test_lon() -> None:
109+
spec = DataSpec.from_dataclass(Weather).specs.of_coord["lon"]
110+
111+
assert spec.name == "Relative longitude"
112+
assert spec.role == "coord"
113+
assert spec.dims == ("lon",)
114+
assert spec.dtype == np.dtype("f8")
115+
assert spec.default is MISSING
116+
assert spec.origin is Lon
117+
118+
119+
def test_lat() -> None:
120+
spec = DataSpec.from_dataclass(Weather).specs.of_coord["lat"]
121+
122+
assert spec.name == "Relative latitude"
123+
assert spec.role == "coord"
124+
assert spec.dims == ("lat",)
125+
assert spec.dtype == np.dtype("f8")
126+
assert spec.default is MISSING
127+
assert spec.origin is Lat
128+
129+
130+
def test_time() -> None:
131+
spec = DataSpec.from_dataclass(Weather).specs.of_coord["time"]
132+
133+
assert spec.name == "Time in UTC"
134+
assert spec.role == "coord"
135+
assert spec.dims == ("time",)
136+
assert spec.dtype == np.dtype("M8[ns]")
137+
assert spec.default is MISSING
138+
assert spec.origin is Time
139+
140+
141+
def test_location() -> None:
142+
spec = DataSpec.from_dataclass(Weather).specs.of_attr["location"]
143+
144+
assert spec.name == "location"
145+
assert spec.role == "attr"
146+
assert spec.type is str
147+
assert spec.default == "Tokyo"
148+
149+
150+
def test_longitude() -> None:
151+
spec = DataSpec.from_dataclass(Weather).specs.of_attr["longitude"]
152+
153+
assert spec.name == "longitude"
154+
assert spec.role == "attr"
155+
assert spec.type is float
156+
assert spec.default == 139.69167
157+
158+
159+
def test_latitude() -> None:
160+
spec = DataSpec.from_dataclass(Weather).specs.of_attr["latitude"]
161+
162+
assert spec.name == "latitude"
163+
assert spec.role == "attr"
164+
assert spec.type is float
165+
assert spec.default == 35.68944
166+
167+
168+
def test_name() -> None:
169+
spec = DataSpec.from_dataclass(Weather).specs.of_name["name"]
170+
171+
assert spec.name == "name"
172+
assert spec.role == "name"
173+
assert spec.type is str
174+
assert spec.default == "weather"
175+
176+
177+
def test_dataoptions() -> None:
178+
options = DataOptions(xr.DataArray)
179+
180+
assert DataSpec().options.factory is type(None)
181+
assert DataSpec(options=options).options.factory is xr.DataArray

tests/test_typing.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
Attr,
1515
Coord,
1616
Data,
17-
FType,
1817
Name,
18+
Role,
1919
get_dims,
2020
get_dtype,
21-
get_ftype,
2221
get_name,
22+
get_role,
2323
)
2424

2525

@@ -54,24 +54,6 @@
5454
(Union[Ann[Data[Any, float], "data"], Ann[Any, "any"]], np.dtype("f8")),
5555
]
5656

57-
testdata_ftype = [
58-
(Attr[Any], FType.ATTR),
59-
(Data[Any, Any], FType.DATA),
60-
(Coord[Any, Any], FType.COORD),
61-
(Name[Any], FType.NAME),
62-
(Any, FType.OTHER),
63-
(Ann[Attr[Any], "attr"], FType.ATTR),
64-
(Ann[Data[Any, Any], "data"], FType.DATA),
65-
(Ann[Coord[Any, Any], "coord"], FType.COORD),
66-
(Ann[Name[Any], "name"], FType.NAME),
67-
(Ann[Any, "other"], FType.OTHER),
68-
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], FType.ATTR),
69-
(Union[Ann[Data[Any, Any], "data"], Ann[Any, "any"]], FType.DATA),
70-
(Union[Ann[Coord[Any, Any], "coord"], Ann[Any, "any"]], FType.COORD),
71-
(Union[Ann[Name[Any], "name"], Ann[Any, "any"]], FType.NAME),
72-
(Union[Ann[Any, "other"], Ann[Any, "any"]], FType.OTHER),
73-
]
74-
7557
testdata_name = [
7658
(Attr[Any], None),
7759
(Data[Any, Any], None),
@@ -90,6 +72,24 @@
9072
(Union[Ann[Any, "other"], Ann[Any, "any"]], None),
9173
]
9274

75+
testdata_role = [
76+
(Attr[Any], Role.ATTR),
77+
(Data[Any, Any], Role.DATA),
78+
(Coord[Any, Any], Role.COORD),
79+
(Name[Any], Role.NAME),
80+
(Any, Role.OTHER),
81+
(Ann[Attr[Any], "attr"], Role.ATTR),
82+
(Ann[Data[Any, Any], "data"], Role.DATA),
83+
(Ann[Coord[Any, Any], "coord"], Role.COORD),
84+
(Ann[Name[Any], "name"], Role.NAME),
85+
(Ann[Any, "other"], Role.OTHER),
86+
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], Role.ATTR),
87+
(Union[Ann[Data[Any, Any], "data"], Ann[Any, "any"]], Role.DATA),
88+
(Union[Ann[Coord[Any, Any], "coord"], Ann[Any, "any"]], Role.COORD),
89+
(Union[Ann[Name[Any], "name"], Ann[Any, "any"]], Role.NAME),
90+
(Union[Ann[Any, "other"], Ann[Any, "any"]], Role.OTHER),
91+
]
92+
9393

9494
# test functions
9595
@mark.parametrize("tp, dims", testdata_dims)
@@ -102,11 +102,11 @@ def test_get_dtype(tp: Any, dtype: Any) -> None:
102102
assert get_dtype(tp) == dtype
103103

104104

105-
@mark.parametrize("tp, ftype", testdata_ftype)
106-
def test_get_ftype(tp: Any, ftype: Any) -> None:
107-
assert get_ftype(tp) == ftype
108-
109-
110105
@mark.parametrize("tp, name", testdata_name)
111106
def test_get_name(tp: Any, name: Any) -> None:
112107
assert get_name(tp) == name
108+
109+
110+
@mark.parametrize("tp, role", testdata_role)
111+
def test_get_role(tp: Any, role: Any) -> None:
112+
assert get_role(tp) == role

xarray_dataclasses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .dataset import *
88
from .datamodel import *
99
from .dataoptions import *
10+
from .specs import *
1011
from .typing import *
1112

1213

xarray_dataclasses/datamodel.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717
from .typing import (
1818
AnyDType,
1919
AnyField,
20-
DataClass,
2120
AnyXarray,
21+
DataClass,
2222
Dims,
23-
FType,
23+
Role,
2424
get_annotated,
2525
get_dataclass,
2626
get_dims,
2727
get_dtype,
28-
get_ftype,
2928
get_name,
29+
get_role,
3030
)
3131

3232

@@ -209,29 +209,29 @@ def eval_dataclass(dataclass: AnyDataClass[PInit]) -> None:
209209

210210
def get_entry(field: AnyField, value: Any) -> Optional[AnyEntry]:
211211
"""Create an entry from a field and its value."""
212-
ftype = get_ftype(field.type)
212+
role = get_role(field.type)
213213
name = get_name(field.type, field.name)
214214

215-
if ftype is FType.ATTR or ftype is FType.NAME:
215+
if role is Role.ATTR or role is Role.NAME:
216216
return AttrEntry(
217217
name=name,
218-
tag=ftype.value,
218+
tag=role.value,
219219
value=value,
220220
type=get_annotated(field.type),
221221
)
222222

223-
if ftype is FType.COORD or ftype is FType.DATA:
223+
if role is Role.COORD or role is Role.DATA:
224224
try:
225225
return DataEntry(
226226
name=name,
227-
tag=ftype.value,
227+
tag=role.value,
228228
base=get_dataclass(field.type),
229229
value=value,
230230
)
231231
except TypeError:
232232
return DataEntry(
233233
name=name,
234-
tag=ftype.value,
234+
tag=role.value,
235235
dims=get_dims(field.type),
236236
dtype=get_dtype(field.type),
237237
value=value,

0 commit comments

Comments
 (0)