Skip to content

Commit 2a2e84b

Browse files
authored
#158 Merge pull request from astropenguin/astropenguin/issue157
Update typing module
2 parents f448dc0 + 75a7dd1 commit 2a2e84b

File tree

8 files changed

+327
-356
lines changed

8 files changed

+327
-356
lines changed

poetry.lock

Lines changed: 106 additions & 117 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ documentation = "https://astropenguin.github.io/xarray-dataclasses/"
1212
[tool.poetry.dependencies]
1313
python = ">=3.7.1, <3.11"
1414
morecopy = "^0.2"
15-
more-itertools = "^8.12"
1615
numpy = [
1716
{ version = ">=1.15, <1.22", python = ">=3.7.1, <3.8" },
1817
{ version = "^1.15", python = ">=3.8, <3.11" },
@@ -27,7 +26,7 @@ xarray = [
2726
black = "^22.3"
2827
ipython = [
2928
{ version = "^7.32", python = ">=3.7.1, <3.8" },
30-
{ version = "^8.2", python = ">=3.8, <3.11" },
29+
{ version = "^8.4", python = ">=3.8, <3.11" },
3130
]
3231
myst-parser = "^0.17"
3332
pydata-sphinx-theme = "^0.8"

tests/test_typing.py

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,88 @@
11
# standard library
2-
from typing import Any, Optional, Tuple, Union
2+
from typing import Any, Tuple, Union
33

44

55
# dependencies
6+
import numpy as np
67
from pytest import mark
7-
from typing_extensions import Annotated, Literal
8+
from typing_extensions import Annotated as Ann
9+
from typing_extensions import Literal as L
810

911

1012
# submodules
1113
from xarray_dataclasses.typing import (
1214
Attr,
13-
Collection,
1415
Coord,
1516
Data,
17+
FType,
1618
Name,
1719
get_dims,
1820
get_dtype,
19-
get_field_type,
20-
get_repr_type,
21+
get_ftype,
2122
)
2223

2324

24-
# type hints
25-
Int64 = Literal["int64"]
26-
NoneType = type(None)
27-
X = Literal["x"]
28-
Y = Literal["y"]
29-
30-
3125
# test datasets
3226
testdata_dims = [
33-
(X, ("x",)),
34-
(Tuple[()], ()),
35-
(Tuple[X], ("x",)),
36-
(Tuple[X, Y], ("x", "y")),
37-
(Collection[X, Any], ("x",)),
38-
(Collection[Tuple[()], Any], ()),
39-
(Collection[Tuple[X], Any], ("x",)),
40-
(Collection[Tuple[X, Y], Any], ("x", "y")),
27+
(Coord[Tuple[()], Any], ()),
28+
(Coord[L["x"], Any], ("x",)),
29+
(Coord[Tuple[L["x"]], Any], ("x",)),
30+
(Coord[Tuple[L["x"], L["y"]], Any], ("x", "y")),
31+
(Data[Tuple[()], Any], ()),
32+
(Data[L["x"], Any], ("x",)),
33+
(Data[Tuple[L["x"]], Any], ("x",)),
34+
(Data[Tuple[L["x"], L["y"]], Any], ("x", "y")),
35+
(Ann[Coord[L["x"], Any], "coord"], ("x",)),
36+
(Ann[Data[L["x"], Any], "data"], ("x",)),
37+
(Union[Ann[Coord[L["x"], Any], "coord"], Ann[Any, "any"]], ("x",)),
38+
(Union[Ann[Data[L["x"], Any], "data"], Ann[Any, "any"]], ("x",)),
4139
]
4240

4341
testdata_dtype = [
44-
(Any, None),
45-
(NoneType, None),
46-
(Int64, "int64"),
47-
(int, "int"),
48-
(Collection[Any, Any], None),
49-
(Collection[Any, NoneType], None),
50-
(Collection[Any, Int64], "int64"),
51-
(Collection[Any, int], "int"),
52-
]
53-
54-
testdata_field_type = [
55-
(Attr[Any], "attr"),
56-
(Coord[Any, Any], "coord"),
57-
(Data[Any, Any], "data"),
58-
(Name[Any], "name"),
42+
(Coord[Any, Any], None),
43+
(Coord[Any, None], None),
44+
(Coord[Any, int], np.dtype("i8")),
45+
(Coord[Any, L["i8"]], np.dtype("i8")),
46+
(Data[Any, Any], None),
47+
(Data[Any, None], None),
48+
(Data[Any, int], np.dtype("i8")),
49+
(Data[Any, L["i8"]], np.dtype("i8")),
50+
(Ann[Coord[Any, float], "coord"], np.dtype("f8")),
51+
(Ann[Data[Any, float], "data"], np.dtype("f8")),
52+
(Union[Ann[Coord[Any, float], "coord"], Ann[Any, "any"]], np.dtype("f8")),
53+
(Union[Ann[Data[Any, float], "data"], Ann[Any, "any"]], np.dtype("f8")),
5954
]
6055

61-
testdata_repr_type = [
62-
(int, int),
63-
(Annotated[int, "annotation"], int),
64-
(Union[int, float], int),
65-
(Optional[int], int),
56+
testdata_ftype = [
57+
(Attr[Any], FType.ATTR),
58+
(Data[Any, Any], FType.DATA),
59+
(Coord[Any, Any], FType.COORD),
60+
(Name[Any], FType.NAME),
61+
(Any, FType.OTHER),
62+
(Ann[Attr[Any], "attr"], FType.ATTR),
63+
(Ann[Data[Any, Any], "data"], FType.DATA),
64+
(Ann[Coord[Any, Any], "coord"], FType.COORD),
65+
(Ann[Name[Any], "name"], FType.NAME),
66+
(Ann[Any, "other"], FType.OTHER),
67+
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], FType.ATTR),
68+
(Union[Ann[Data[Any, Any], "data"], Ann[Any, "any"]], FType.DATA),
69+
(Union[Ann[Coord[Any, Any], "coord"], Ann[Any, "any"]], FType.COORD),
70+
(Union[Ann[Name[Any], "name"], Ann[Any, "any"]], FType.NAME),
71+
(Union[Ann[Any, "other"], Ann[Any, "any"]], FType.OTHER),
6672
]
6773

6874

6975
# test functions
70-
@mark.parametrize("type_, dims", testdata_dims)
71-
def test_get_dims(type_: Any, dims: Any) -> None:
72-
assert get_dims(type_) == dims
73-
74-
75-
@mark.parametrize("type_, dtype", testdata_dtype)
76-
def test_get_dtype(type_: Any, dtype: Any) -> None:
77-
assert get_dtype(type_) == dtype
76+
@mark.parametrize("tp, dims", testdata_dims)
77+
def test_get_dims(tp: Any, dims: Any) -> None:
78+
assert get_dims(tp) == dims
7879

7980

80-
@mark.parametrize("type_, field_type", testdata_field_type)
81-
def test_get_field_type(type_: Any, field_type: Any) -> None:
82-
assert get_field_type(type_).value == field_type
81+
@mark.parametrize("tp, dtype", testdata_dtype)
82+
def test_get_dtype(tp: Any, dtype: Any) -> None:
83+
assert get_dtype(tp) == dtype
8384

8485

85-
@mark.parametrize("type_, repr_type", testdata_repr_type)
86-
def test_get_repr_type(type_: Any, repr_type: Any) -> None:
87-
assert get_repr_type(type_) == repr_type
86+
@mark.parametrize("tp, ftype", testdata_ftype)
87+
def test_get_ftype(tp: Any, ftype: Any) -> None:
88+
assert get_ftype(tp) == ftype

xarray_dataclasses/dataarray.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,7 @@
55
# standard library
66
from functools import partial, wraps
77
from types import MethodType
8-
from typing import (
9-
Any,
10-
Callable,
11-
ClassVar,
12-
Optional,
13-
Type,
14-
TypeVar,
15-
Union,
16-
overload,
17-
)
8+
from typing import Any, Callable, Optional, Type, TypeVar, Union, overload
189

1910

2011
# dependencies
@@ -27,29 +18,25 @@
2718
# submodules
2819
from .datamodel import DataModel
2920
from .dataoptions import DataOptions
30-
from .typing import AnyArray, DataClass, DataClassFields, DataType, Order, Shape, Sizes
21+
from .typing import AnyArray, AnyXarray, DataClass, Order, Shape, Sizes
3122

3223

3324
# type hints
3425
PInit = ParamSpec("PInit")
3526
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
3627

3728

38-
class OptionedClass(Protocol[PInit, TDataArray]):
29+
class OptionedClass(DataClass[PInit], Protocol[PInit, TDataArray]):
3930
"""Type hint for dataclass objects with options."""
4031

41-
def __init__(self, *args: PInit.args, **kwargs: PInit.kwargs) -> None:
42-
...
43-
44-
__dataclass_fields__: ClassVar[DataClassFields]
4532
__dataoptions__: DataOptions[TDataArray]
4633

4734

4835
# runtime functions
4936
@overload
5037
def asdataarray(
5138
dataclass: OptionedClass[PInit, TDataArray],
52-
reference: Optional[DataType] = None,
39+
reference: Optional[AnyXarray] = None,
5340
dataoptions: None = None,
5441
) -> TDataArray:
5542
...
@@ -58,7 +45,7 @@ def asdataarray(
5845
@overload
5946
def asdataarray(
6047
dataclass: DataClass[PInit],
61-
reference: Optional[DataType] = None,
48+
reference: Optional[AnyXarray] = None,
6249
dataoptions: None = None,
6350
) -> xr.DataArray:
6451
...
@@ -67,15 +54,15 @@ def asdataarray(
6754
@overload
6855
def asdataarray(
6956
dataclass: Any,
70-
reference: Optional[DataType] = None,
57+
reference: Optional[AnyXarray] = None,
7158
dataoptions: DataOptions[TDataArray] = DataOptions(xr.DataArray),
7259
) -> TDataArray:
7360
...
7461

7562

7663
def asdataarray(
7764
dataclass: Any,
78-
reference: Optional[DataType] = None,
65+
reference: Optional[AnyXarray] = None,
7966
dataoptions: Any = None,
8067
) -> Any:
8168
"""Create a DataArray object from a dataclass object.

xarray_dataclasses/datamodel.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@
1515

1616
# submodules
1717
from .typing import (
18+
AnyDType,
1819
AnyField,
1920
DataClass,
20-
DataType,
21+
AnyXarray,
2122
Dims,
22-
Dtype,
23-
FieldType,
23+
FType,
24+
get_annotated,
25+
get_dataclass,
2426
get_dims,
2527
get_dtype,
26-
get_field_type,
27-
get_repr_type,
28+
get_ftype,
2829
)
2930

3031

@@ -93,7 +94,7 @@ class DataEntry:
9394
dims: Dims = cast(Dims, None)
9495
"""Dimensions of the DataArray that the data is cast to."""
9596

96-
dtype: Dtype = cast(Dtype, None)
97+
dtype: Optional[AnyDType] = None
9798
"""Data type of the DataArray that the data is cast to."""
9899

99100
base: Optional[Type[DataClass[Any]]] = None
@@ -119,7 +120,7 @@ def __post_init__(self) -> None:
119120
if model.names:
120121
setattr(self, "name", model.names[0].value)
121122

122-
def __call__(self, reference: Optional[DataType] = None) -> xr.DataArray:
123+
def __call__(self, reference: Optional[AnyXarray] = None) -> xr.DataArray:
123124
"""Create a DataArray object according to the entry."""
124125
from .dataarray import asdataarray
125126

@@ -174,11 +175,11 @@ def from_dataclass(cls, dataclass: AnyDataClass[PInit]) -> "DataModel":
174175
eval_dataclass(dataclass)
175176

176177
for field in dataclass.__dataclass_fields__.values():
177-
try:
178-
value = getattr(dataclass, field.name, MISSING)
179-
model.entries[field.name] = get_entry(field, value)
180-
except TypeError:
181-
pass
178+
value = getattr(dataclass, field.name, MISSING)
179+
entry = get_entry(field, value)
180+
181+
if entry is not None:
182+
model.entries[field.name] = entry
182183

183184
return model
184185

@@ -205,42 +206,41 @@ def eval_dataclass(dataclass: AnyDataClass[PInit]) -> None:
205206
field.type = types[field.name]
206207

207208

208-
def get_entry(field: AnyField, value: Any) -> AnyEntry:
209+
def get_entry(field: AnyField, value: Any) -> Optional[AnyEntry]:
209210
"""Create an entry from a field and its value."""
210-
field_type = get_field_type(field.type)
211-
repr_type = get_repr_type(field.type)
211+
ftype = get_ftype(field.type)
212212

213-
if field_type is FieldType.ATTR or field_type is FieldType.NAME:
213+
if ftype is FType.ATTR or ftype is FType.NAME:
214214
return AttrEntry(
215215
name=field.name,
216-
tag=field_type.value,
216+
tag=ftype.value,
217217
value=value,
218-
type=repr_type,
218+
type=get_annotated(field.type),
219219
)
220220

221-
# hereafter field type is either COORD or DATA
222-
if is_dataclass(repr_type):
223-
return DataEntry(
224-
name=field.name,
225-
tag=field_type.value,
226-
base=repr_type,
227-
value=value,
228-
)
229-
else:
230-
return DataEntry(
231-
name=field.name,
232-
tag=field_type.value,
233-
dims=get_dims(repr_type),
234-
dtype=get_dtype(repr_type),
235-
value=value,
236-
)
221+
if ftype is FType.COORD or ftype is FType.DATA:
222+
try:
223+
return DataEntry(
224+
name=field.name,
225+
tag=ftype.value,
226+
base=get_dataclass(field.type),
227+
value=value,
228+
)
229+
except TypeError:
230+
return DataEntry(
231+
name=field.name,
232+
tag=ftype.value,
233+
dims=get_dims(field.type),
234+
dtype=get_dtype(field.type),
235+
value=value,
236+
)
237237

238238

239239
def get_typedarray(
240240
data: Any,
241241
dims: Dims,
242-
dtype: Dtype,
243-
reference: Optional[DataType] = None,
242+
dtype: Optional[AnyDType],
243+
reference: Optional[AnyXarray] = None,
244244
) -> xr.DataArray:
245245
"""Create a DataArray object with given dims and dtype.
246246

xarray_dataclasses/dataoptions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88

99

1010
# submodules
11-
from .typing import DataType
11+
from .typing import AnyXarray
1212

1313

1414
# type hints
15-
TDataType = TypeVar("TDataType", bound=DataType)
15+
TAnyXarray = TypeVar("TAnyXarray", bound=AnyXarray)
1616

1717

1818
# dataclasses
1919
@dataclass(frozen=True)
20-
class DataOptions(Generic[TDataType]):
20+
class DataOptions(Generic[TAnyXarray]):
2121
"""Options for DataArray or Dataset creation."""
2222

23-
factory: Callable[..., TDataType]
23+
factory: Callable[..., TAnyXarray]
2424
"""Factory function for DataArray or Dataset."""

0 commit comments

Comments
 (0)