Skip to content

Commit b97fa4a

Browse files
authored
Merge pull request #138 from astropenguin/astropenguin/issue137
Update data model
2 parents a51be9e + 0438a65 commit b97fa4a

File tree

8 files changed

+538
-337
lines changed

8 files changed

+538
-337
lines changed

poetry.lock

Lines changed: 102 additions & 62 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@ documentation = "https://astropenguin.github.io/xarray-dataclasses/"
1212
[tool.poetry.dependencies]
1313
python = ">=3.7.1, <3.10"
1414
morecopy = "^0.2"
15+
more-itertools = "^8.12"
1516
numpy = "^1.19"
1617
typing-extensions = "^3.10"
1718
xarray = ">=0.18, <0.30"
1819

1920
[tool.poetry.dev-dependencies]
20-
black = "^21.12b"
21+
black = "^22.1"
2122
ipython = "^7.31"
2223
myst-parser = "^0.16"
23-
pydata-sphinx-theme = "^0.7"
24+
pydata-sphinx-theme = "^0.8"
2425
pytest = "^6.2"
25-
sphinx = "^4.3"
26+
sphinx = "^4.4"
2627

2728
[build-system]
2829
requires = ["poetry-core>=1.0.0"]

tests/test_datamodel.py

Lines changed: 79 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -53,73 +53,100 @@ class ColorImage:
5353

5454
# test functions
5555
def test_xaxis_attr() -> None:
56-
item = next(iter(xaxis_model.attr.values()))
57-
assert item.name == "units"
58-
assert item.value == "pixel"
59-
assert item.type == "builtins.str"
56+
units = xaxis_model.attrs[0]
57+
assert units.name == "units"
58+
assert units.tag == "attr"
59+
assert units.type is str
60+
assert units.value == "pixel"
61+
assert units.cast == False
6062

6163

6264
def test_xaxis_data() -> None:
63-
item = next(iter(xaxis_model.data.values()))
64-
assert item.name == "data"
65-
assert item.type == {"dims": ("x",), "dtype": "int"}
66-
assert item.factory is None
65+
data = xaxis_model.data_vars[0]
66+
assert data.name == "data"
67+
assert data.tag == "data"
68+
assert data.dims == ("x",)
69+
assert data.dtype == "int"
70+
assert data.base is None
71+
assert data.cast == True
6772

6873

6974
def test_yaxis_attr() -> None:
70-
item = next(iter(yaxis_model.attr.values()))
71-
assert item.name == "units"
72-
assert item.value == "pixel"
73-
assert item.type == "builtins.str"
75+
units = yaxis_model.attrs[0]
76+
assert units.name == "units"
77+
assert units.tag == "attr"
78+
assert units.type is str
79+
assert units.value == "pixel"
80+
assert units.cast == False
7481

7582

7683
def test_yaxis_data() -> None:
77-
item = next(iter(yaxis_model.data.values()))
78-
assert item.name == "data"
79-
assert item.type == {"dims": ("y",), "dtype": "int"}
80-
assert item.factory is None
84+
data = yaxis_model.data_vars[0]
85+
assert data.name == "data"
86+
assert data.tag == "data"
87+
assert data.dims == ("y",)
88+
assert data.dtype == "int"
89+
assert data.base is None
90+
assert data.cast == True
8191

8292

8393
def test_image_coord() -> None:
84-
items = iter(image_model.coord.values())
85-
86-
item = next(items)
87-
assert item.name == "mask"
88-
assert item.type == {"dims": ("x", "y"), "dtype": "bool"}
89-
assert item.factory is None
90-
91-
item = next(items)
92-
assert item.name == "x"
93-
assert item.type == {"dims": ("x",), "dtype": "int"}
94-
assert item.factory is XAxis
95-
96-
item = next(items)
97-
assert item.name == "y"
98-
assert item.type == {"dims": ("y",), "dtype": "int"}
99-
assert item.factory is YAxis
94+
mask = image_model.coords[0]
95+
assert mask.name == "mask"
96+
assert mask.tag == "coord"
97+
assert mask.dims == ("x", "y")
98+
assert mask.dtype == "bool"
99+
assert mask.base is None
100+
assert mask.cast == True
101+
102+
x = image_model.coords[1]
103+
assert x.name == "x"
104+
assert x.tag == "coord"
105+
assert x.dims == ("x",)
106+
assert x.dtype == "int"
107+
assert x.base is XAxis
108+
assert x.cast == True
109+
110+
y = image_model.coords[2]
111+
assert y.name == "y"
112+
assert y.tag == "coord"
113+
assert y.dims == ("y",)
114+
assert y.dtype == "int"
115+
assert y.base is YAxis
116+
assert y.cast == True
100117

101118

102119
def test_image_data() -> None:
103-
item = next(iter(image_model.data.values()))
104-
assert item.name == "data"
105-
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
106-
assert item.factory is None
120+
data = image_model.data_vars[0]
121+
assert data.name == "data"
122+
assert data.tag == "data"
123+
assert data.dims == ("x", "y")
124+
assert data.dtype == "float"
125+
assert data.base is None
126+
assert data.cast == True
107127

108128

109129
def test_color_data() -> None:
110-
items = iter(color_model.data.values())
111-
112-
item = next(items)
113-
assert item.name == "red"
114-
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
115-
assert item.factory is Image
116-
117-
item = next(items)
118-
assert item.name == "green"
119-
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
120-
assert item.factory is Image
121-
122-
item = next(items)
123-
assert item.name == "blue"
124-
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
125-
assert item.factory is Image
130+
red = color_model.data_vars[0]
131+
assert red.name == "red"
132+
assert red.tag == "data"
133+
assert red.dims == ("x", "y")
134+
assert red.dtype == "float"
135+
assert red.base is Image
136+
assert red.cast == True
137+
138+
green = color_model.data_vars[1]
139+
assert green.name == "green"
140+
assert green.tag == "data"
141+
assert green.dims == ("x", "y")
142+
assert green.dtype == "float"
143+
assert green.base is Image
144+
assert green.cast == True
145+
146+
blue = color_model.data_vars[2]
147+
assert blue.name == "blue"
148+
assert blue.tag == "data"
149+
assert blue.dims == ("x", "y")
150+
assert blue.dtype == "float"
151+
assert blue.base is Image
152+
assert blue.cast == True

tests/test_typing.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,87 @@
11
# standard library
2-
from typing import Any, Tuple
2+
from typing import Any, Optional, Tuple, Union
33

44

55
# third-party packages
66
from pytest import mark
7-
from typing_extensions import Literal
7+
from typing_extensions import Annotated, Literal
88

99

1010
# submodules
11-
from xarray_dataclasses.typing import Data, get_dims, get_dtype, unannotate
11+
from xarray_dataclasses.typing import (
12+
ArrayLike,
13+
Attr,
14+
Coord,
15+
Data,
16+
Name,
17+
get_dims,
18+
get_dtype,
19+
get_field_type,
20+
get_repr_type,
21+
)
1222

1323

1424
# type hints
15-
Int = Literal["int"]
25+
Int64 = Literal["int64"]
26+
NoneType = type(None)
1627
X = Literal["x"]
1728
Y = Literal["y"]
1829

1930

2031
# test datasets
2132
testdata_dims = [
22-
(Data[X, Any], ("x",)),
23-
(Data[Tuple[()], Any], ()),
24-
(Data[Tuple[X], Any], ("x",)),
25-
(Data[Tuple[X, Y], Any], ("x", "y")),
33+
(X, ("x",)),
34+
(Tuple[()], ()),
35+
(Tuple[X], ("x",)),
36+
(Tuple[X, Y], ("x", "y")),
37+
(ArrayLike[X, Any], ("x",)),
38+
(ArrayLike[Tuple[()], Any], ()),
39+
(ArrayLike[Tuple[X], Any], ("x",)),
40+
(ArrayLike[Tuple[X, Y], Any], ("x", "y")),
2641
]
2742

2843
testdata_dtype = [
29-
(Data[X, Any], None),
30-
(Data[X, None], None),
31-
(Data[X, int], "int"),
32-
(Data[X, Int], "int"),
44+
(Any, None),
45+
(NoneType, None),
46+
(Int64, "int64"),
47+
(int, "int"),
48+
(ArrayLike[Any, Any], None),
49+
(ArrayLike[Any, NoneType], None),
50+
(ArrayLike[Any, Int64], "int64"),
51+
(ArrayLike[Any, int], "int"),
52+
]
53+
54+
testdata_field_type = [
55+
(Attr[Any], "attr"),
56+
(Coord[Any, Any], "coord"),
57+
(Data[Any, Any], "data"),
58+
(Name[Any], "name"),
59+
]
60+
61+
testdata_repr_type = [
62+
(int, int),
63+
(Annotated[int, "annotation"], int),
64+
(Union[int, float], int),
65+
(Optional[int], int),
3366
]
3467

3568

3669
# test functions
37-
@mark.parametrize("hint, dims", testdata_dims)
38-
def test_get_dims(hint: Any, dims: Any) -> None:
39-
assert get_dims(unannotate(hint)) == dims
70+
@mark.parametrize("type_, dims", testdata_dims)
71+
def test_get_dims(type_: Any, dims: Any) -> None:
72+
assert get_dims(type_) == dims
73+
74+
75+
@mark.parametrize("type_, dtype", testdata_dtype)
76+
def test_get_dtype(type_: Any, dtype: Any) -> None:
77+
assert get_dtype(type_) == dtype
78+
79+
80+
@mark.parametrize("type_, field_type", testdata_field_type)
81+
def test_get_field_type(type_: Any, field_type: Any) -> None:
82+
assert get_field_type(type_).value == field_type
4083

4184

42-
@mark.parametrize("hint, dtype", testdata_dtype)
43-
def test_get_dtype(hint: Any, dtype: Any) -> None:
44-
assert get_dtype(unannotate(hint)) == dtype
85+
@mark.parametrize("type_, repr_type", testdata_repr_type)
86+
def test_get_repr_type(type_: Any, repr_type: Any) -> None:
87+
assert get_repr_type(type_) == repr_type

xarray_dataclasses/dataarray.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,21 @@ def asdataarray(
104104
dataoptions = DataOptions(xr.DataArray)
105105

106106
model = DataModel.from_dataclass(dataclass)
107-
item = next(iter(model.data.values()))
108-
dataarray = dataoptions.factory(item(reference))
107+
dataarray = dataoptions.factory(model.data_vars[0](reference))
109108

110-
for item in model.coord.values():
111-
if item.name in dataarray.dims:
112-
dataarray.coords[item.name] = item(dataarray)
109+
for entry in model.coords:
110+
if entry.name in dataarray.dims:
111+
dataarray.coords[entry.name] = entry(dataarray)
113112

114-
for item in model.coord.values():
115-
if item.name not in dataarray.dims:
116-
dataarray.coords[item.name] = item(dataarray)
113+
for entry in model.coords:
114+
if entry.name not in dataarray.dims:
115+
dataarray.coords[entry.name] = entry(dataarray)
117116

118-
for item in model.attr.values():
119-
dataarray.attrs[item.name] = item()
117+
for entry in model.attrs:
118+
dataarray.attrs[entry.name] = entry()
120119

121-
if model.name:
122-
item = next(iter(model.name.values()))
123-
dataarray.name = item()
120+
if model.names:
121+
dataarray.name = model.names[0]()
124122

125123
return dataarray
126124

@@ -213,12 +211,12 @@ def shaped(
213211
214212
"""
215213
model = DataModel.from_dataclass(cls)
216-
name, item = next(iter(model.data.items()))
214+
key, entry = model.data_vars_items[0]
217215

218216
if isinstance(shape, dict):
219-
shape = tuple(shape[dim] for dim in item.type["dims"])
217+
shape = tuple(shape[dim] for dim in entry.dims)
220218

221-
return asdataarray(cls(**{name: func(shape)}, **kwargs))
219+
return asdataarray(cls(**{key: func(shape)}, **kwargs))
222220

223221
@overload
224222
@classmethod

0 commit comments

Comments
 (0)