Skip to content

Commit a51be9e

Browse files
authored
Merge pull request #136 from astropenguin/astropenguin/issue135
Preferentially use data options if they are given in asdata* functions
2 parents d2a5b8e + bf8479e commit a51be9e

File tree

7 files changed

+109
-93
lines changed

7 files changed

+109
-93
lines changed

README.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,11 @@ class Custom(xr.DataArray):
311311
print("Custom method!")
312312

313313

314-
dataoptions = DataOptions(Custom)
315-
316-
317314
@dataclass
318315
class Image(AsDataArray):
319316
"""Specs for a monochromatic image."""
320317

321-
__dataoptions__ = dataoptions
318+
__dataoptions__ = DataOptions(Custom)
322319

323320
data: Data[tuple[X, Y], float]
324321
x: Coord[X, int] = 0

tests/test_dataarray.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,11 @@ class Custom(xr.DataArray):
3232
__slots__ = ()
3333

3434

35-
dataoptions = DataOptions(Custom)
36-
37-
3835
@dataclass
3936
class Image(AsDataArray):
4037
"""Specs for a monochromatic image."""
4138

42-
__dataoptions__ = dataoptions
39+
__dataoptions__ = DataOptions(Custom)
4340

4441
data: Data[Tuple[X, Y], float]
4542
x: Coord[X, int] = 0

tests/test_dataset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ class Custom(xr.Dataset):
3030
__slots__ = ()
3131

3232

33-
dataoptions = DataOptions(Custom)
34-
35-
3633
@dataclass
3734
class Image(AsDataArray):
3835
"""Specs for a monochromatic image."""
@@ -44,7 +41,7 @@ class Image(AsDataArray):
4441
class ColorImage(AsDataset):
4542
"""Specs for a color image."""
4643

47-
__dataoptions__ = dataoptions
44+
__dataoptions__ = DataOptions(Custom)
4845

4946
red: Data[Tuple[X, Y], float]
5047
green: Data[Tuple[X, Y], float]

xarray_dataclasses/dataarray.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,18 @@
22

33

44
# standard library
5-
from dataclasses import Field
65
from functools import partial, wraps
76
from types import MethodType
8-
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
7+
from typing import (
8+
Any,
9+
Callable,
10+
ClassVar,
11+
Optional,
12+
Type,
13+
TypeVar,
14+
Union,
15+
overload,
16+
)
917

1018

1119
# dependencies
@@ -18,60 +26,65 @@
1826
# submodules
1927
from .datamodel import DataModel
2028
from .dataoptions import DataOptions
21-
from .typing import DataType, Order, Shape, Sizes
22-
23-
24-
# constants
25-
DEFAULT_OPTIONS = DataOptions(xr.DataArray)
29+
from .typing import DataClass, DataClassFields, DataType, Order, Shape, Sizes
2630

2731

2832
# type hints
2933
P = ParamSpec("P")
3034
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
3135

3236

33-
class DataClass(Protocol[P]):
34-
"""Type hint for a dataclass object."""
37+
class OptionedClass(Protocol[P, TDataArray]):
38+
"""Type hint for dataclass objects with options."""
3539

3640
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
3741
...
3842

39-
__dataclass_fields__: Dict[str, Field[Any]]
43+
__dataclass_fields__: ClassVar[DataClassFields]
44+
__dataoptions__: DataOptions[TDataArray]
4045

4146

42-
class DataArrayClass(Protocol[P, TDataArray]):
43-
"""Type hint for a dataclass object with a DataArray factory."""
47+
# runtime functions
48+
@overload
49+
def asdataarray(
50+
dataclass: OptionedClass[P, TDataArray],
51+
reference: Optional[DataType] = None,
52+
dataoptions: None = None,
53+
) -> TDataArray:
54+
...
4455

45-
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
46-
...
4756

48-
__dataclass_fields__: Dict[str, Field[Any]]
49-
__dataoptions__: DataOptions[TDataArray]
57+
@overload
58+
def asdataarray(
59+
dataclass: DataClass[P],
60+
reference: Optional[DataType] = None,
61+
dataoptions: None = None,
62+
) -> xr.DataArray:
63+
...
5064

5165

52-
# runtime functions
5366
@overload
5467
def asdataarray(
55-
dataclass: DataArrayClass[Any, TDataArray],
68+
dataclass: OptionedClass[P, Any],
5669
reference: Optional[DataType] = None,
57-
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
70+
dataoptions: Optional[DataOptions[TDataArray]] = None,
5871
) -> TDataArray:
5972
...
6073

6174

6275
@overload
6376
def asdataarray(
64-
dataclass: DataClass[Any],
77+
dataclass: DataClass[P],
6578
reference: Optional[DataType] = None,
66-
dataoptions: DataOptions[TDataArray] = DEFAULT_OPTIONS,
79+
dataoptions: Optional[DataOptions[TDataArray]] = None,
6780
) -> TDataArray:
6881
...
6982

7083

7184
def asdataarray(
7285
dataclass: Any,
7386
reference: Optional[DataType] = None,
74-
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
87+
dataoptions: Any = None,
7588
) -> Any:
7689
"""Create a DataArray object from a dataclass object.
7790
@@ -84,10 +97,11 @@ def asdataarray(
8497
DataArray object created from the dataclass object.
8598
8699
"""
87-
try:
88-
dataoptions = dataclass.__dataoptions__
89-
except AttributeError:
90-
pass
100+
if dataoptions is None:
101+
try:
102+
dataoptions = dataclass.__dataoptions__
103+
except AttributeError:
104+
dataoptions = DataOptions(xr.DataArray)
91105

92106
model = DataModel.from_dataclass(dataclass)
93107
item = next(iter(model.data.values()))
@@ -127,7 +141,7 @@ def __init__(self, func: Any) -> None:
127141
def __get__(
128142
self,
129143
obj: Any,
130-
cls: Type[DataArrayClass[P, TDataArray]],
144+
cls: Type[OptionedClass[P, TDataArray]],
131145
) -> Callable[P, TDataArray]:
132146
...
133147

@@ -163,7 +177,7 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
163177
@overload
164178
@classmethod
165179
def shaped(
166-
cls: Type[DataArrayClass[P, TDataArray]],
180+
cls: Type[OptionedClass[P, TDataArray]],
167181
func: Callable[[Shape], np.ndarray],
168182
shape: Union[Shape, Sizes],
169183
**kwargs: Any,
@@ -209,7 +223,7 @@ def shaped(
209223
@overload
210224
@classmethod
211225
def empty(
212-
cls: Type[DataArrayClass[P, TDataArray]],
226+
cls: Type[OptionedClass[P, TDataArray]],
213227
shape: Union[Shape, Sizes],
214228
order: Order = "C",
215229
**kwargs: Any,
@@ -251,7 +265,7 @@ def empty(
251265
@overload
252266
@classmethod
253267
def zeros(
254-
cls: Type[DataArrayClass[P, TDataArray]],
268+
cls: Type[OptionedClass[P, TDataArray]],
255269
shape: Union[Shape, Sizes],
256270
order: Order = "C",
257271
**kwargs: Any,
@@ -293,7 +307,7 @@ def zeros(
293307
@overload
294308
@classmethod
295309
def ones(
296-
cls: Type[DataArrayClass[P, TDataArray]],
310+
cls: Type[OptionedClass[P, TDataArray]],
297311
shape: Union[Shape, Sizes],
298312
order: Order = "C",
299313
**kwargs: Any,
@@ -335,7 +349,7 @@ def ones(
335349
@overload
336350
@classmethod
337351
def full(
338-
cls: Type[DataArrayClass[P, TDataArray]],
352+
cls: Type[OptionedClass[P, TDataArray]],
339353
shape: Union[Shape, Sizes],
340354
fill_value: Any,
341355
order: Order = "C",

xarray_dataclasses/datamodel.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
# standard library
55
from dataclasses import Field, dataclass, field, is_dataclass
6-
from typing import Any, Dict, Hashable, Optional, Type, cast
6+
from typing import Any, Dict, Hashable, Optional, Type, Union, cast
77

88

99
# dependencies
1010
import numpy as np
1111
import xarray as xr
12-
from typing_extensions import TypedDict, get_type_hints
12+
from typing_extensions import ParamSpec, TypedDict, get_type_hints
1313

1414

1515
# submodules
@@ -28,6 +28,8 @@
2828

2929

3030
# type hints
31+
P = ParamSpec("P")
32+
AnyDataClass = Union[Type[DataClass[P]], DataClass[P]]
3133
DimsDtype = TypedDict("DimsDtype", dims=Dims, dtype=Dtype)
3234

3335

@@ -45,7 +47,7 @@ class Data:
4547
type: DimsDtype
4648
"""Type (dims and dtype) of the field."""
4749

48-
factory: Optional[Type[DataClass]] = None
50+
factory: Any = None
4951
"""Factory dataclass to create a DataArray object."""
5052

5153
def __call__(self, reference: Optional[DataType] = None) -> xr.DataArray:
@@ -137,7 +139,7 @@ class DataModel:
137139
"""Model of the name fields."""
138140

139141
@classmethod
140-
def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
142+
def from_dataclass(cls, dataclass: AnyDataClass[P]) -> "DataModel":
141143
"""Create a data model from a dataclass or its object."""
142144
model = cls()
143145
eval_field_types(dataclass)
@@ -162,7 +164,7 @@ def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
162164

163165

164166
# runtime functions
165-
def eval_field_types(dataclass: DataClass) -> None:
167+
def eval_field_types(dataclass: AnyDataClass[P]) -> None:
166168
"""Evaluate field types of a dataclass or its object."""
167169
hints = get_type_hints(dataclass, include_extras=True) # type: ignore
168170

0 commit comments

Comments
 (0)