Skip to content

Commit b5f4b04

Browse files
authored
Merge pull request #122 from astropenguin/astropenguin/issue121
Fix mix-in classes
2 parents a9a29cf + 19e87d6 commit b5f4b04

File tree

4 files changed

+265
-83
lines changed

4 files changed

+265
-83
lines changed

package-lock.json

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

xarray_dataclasses/dataarray.py

Lines changed: 129 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,50 +28,33 @@
2828
# type hints
2929
P = ParamSpec("P")
3030
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
31-
TDataArray_ = TypeVar("TDataArray_", bound=xr.DataArray, contravariant=True)
3231

3332

3433
class DataClass(Protocol[P]):
3534
"""Type hint for a dataclass object."""
3635

37-
__init__: Callable[P, None]
38-
__dataclass_fields__: Dict[str, Field[Any]]
39-
40-
41-
class DataArrayClass(Protocol[P, TDataArray_]):
42-
"""Type hint for a dataclass object with a DataArray factory."""
36+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
37+
...
4338

44-
__init__: Callable[P, None]
4539
__dataclass_fields__: Dict[str, Field[Any]]
46-
__dataoptions__: DataOptions[TDataArray_]
4740

4841

49-
# custom classproperty
50-
class classproperty:
51-
"""Class property only for AsDataArray.new().
52-
53-
As a classmethod and a property can be chained together since Python 3.9,
54-
this class will be removed when the support for Python 3.7 and 3.8 ends.
55-
56-
"""
42+
class DataArrayClass(Protocol[P, TDataArray]):
43+
"""Type hint for a dataclass object with a DataArray factory."""
5744

58-
def __init__(self, func: Callable[..., Callable[P, TDataArray]]) -> None:
59-
self.__func__ = func
45+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
46+
...
6047

61-
def __get__(
62-
self,
63-
obj: Any,
64-
cls: Type[DataArrayClass[P, TDataArray]],
65-
) -> Callable[P, TDataArray]:
66-
return self.__func__(cls)
48+
__dataclass_fields__: Dict[str, Field[Any]]
49+
__dataoptions__: DataOptions[TDataArray]
6750

6851

69-
# runtime functions and classes
52+
# runtime functions
7053
@overload
7154
def asdataarray(
7255
dataclass: DataArrayClass[Any, TDataArray],
7356
reference: Optional[DataType] = None,
74-
dataoptions: Any = DEFAULT_OPTIONS,
57+
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
7558
) -> TDataArray:
7659
...
7760

@@ -87,8 +70,8 @@ def asdataarray(
8770

8871
def asdataarray(
8972
dataclass: Any,
90-
reference: Any = None,
91-
dataoptions: Any = DEFAULT_OPTIONS,
73+
reference: Optional[DataType] = None,
74+
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
9275
) -> Any:
9376
"""Create a DataArray object from a dataclass object.
9477
@@ -134,36 +117,82 @@ def asdataarray(
134117
return dataarray
135118

136119

120+
# runtime classes
121+
class classproperty:
122+
"""Class property only for AsDataArray.new().
123+
124+
As a classmethod and a property can be chained together since Python 3.9,
125+
this class will be removed when the support for Python 3.7 and 3.8 ends.
126+
127+
"""
128+
129+
def __init__(self, func: Any) -> None:
130+
self.__func__ = func
131+
132+
@overload
133+
def __get__(
134+
self,
135+
obj: Any,
136+
cls: Type[DataArrayClass[P, TDataArray]],
137+
) -> Callable[P, TDataArray]:
138+
...
139+
140+
@overload
141+
def __get__(
142+
self,
143+
obj: Any,
144+
cls: Type[DataClass[P]],
145+
) -> Callable[P, xr.DataArray]:
146+
...
147+
148+
def __get__(self, obj: Any, cls: Any) -> Any:
149+
return self.__func__(cls)
150+
151+
137152
class AsDataArray:
138153
"""Mix-in class that provides shorthand methods."""
139154

140-
__dataoptions__ = DEFAULT_OPTIONS
141-
142155
@classproperty
143-
def new(cls: Type[DataArrayClass[P, TDataArray]]) -> Callable[P, TDataArray]:
156+
def new(cls: Any) -> Any:
144157
"""Create a DataArray object from dataclass parameters."""
145158

146159
init = copy(cls.__init__)
147-
init.__annotations__["return"] = TDataArray
148160
init.__doc__ = cls.__init__.__doc__
161+
init.__annotations__["return"] = TDataArray
149162

150163
@wraps(init)
151-
def new(
152-
cls: Type[DataArrayClass[P, TDataArray]],
153-
*args: P.args,
154-
**kwargs: P.kwargs,
155-
) -> TDataArray:
164+
def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
156165
return asdataarray(cls(*args, **kwargs))
157166

158167
return MethodType(new, cls)
159168

169+
@overload
160170
@classmethod
161171
def empty(
162172
cls: Type[DataArrayClass[P, TDataArray]],
163173
shape: Union[Shape, Sizes],
164174
order: Order = "C",
165175
**kwargs: Any,
166176
) -> TDataArray:
177+
...
178+
179+
@overload
180+
@classmethod
181+
def empty(
182+
cls: Type[DataClass[P]],
183+
shape: Union[Shape, Sizes],
184+
order: Order = "C",
185+
**kwargs: Any,
186+
) -> xr.DataArray:
187+
...
188+
189+
@classmethod
190+
def empty(
191+
cls: Any,
192+
shape: Union[Shape, Sizes],
193+
order: Order = "C",
194+
**kwargs: Any,
195+
) -> Any:
167196
"""Create a DataArray object without initializing data.
168197
169198
Args:
@@ -185,13 +214,33 @@ def empty(
185214
data = np.empty(shape, order=order)
186215
return asdataarray(cls(**{name: data}, **kwargs))
187216

217+
@overload
188218
@classmethod
189219
def zeros(
190220
cls: Type[DataArrayClass[P, TDataArray]],
191221
shape: Union[Shape, Sizes],
192222
order: Order = "C",
193223
**kwargs: Any,
194224
) -> TDataArray:
225+
...
226+
227+
@overload
228+
@classmethod
229+
def zeros(
230+
cls: Type[DataClass[P]],
231+
shape: Union[Shape, Sizes],
232+
order: Order = "C",
233+
**kwargs: Any,
234+
) -> xr.DataArray:
235+
...
236+
237+
@classmethod
238+
def zeros(
239+
cls: Any,
240+
shape: Union[Shape, Sizes],
241+
order: Order = "C",
242+
**kwargs: Any,
243+
) -> Any:
195244
"""Create a DataArray object filled with zeros.
196245
197246
Args:
@@ -213,13 +262,33 @@ def zeros(
213262
data = np.zeros(shape, order=order)
214263
return asdataarray(cls(**{name: data}, **kwargs))
215264

265+
@overload
216266
@classmethod
217267
def ones(
218268
cls: Type[DataArrayClass[P, TDataArray]],
219269
shape: Union[Shape, Sizes],
220270
order: Order = "C",
221271
**kwargs: Any,
222272
) -> TDataArray:
273+
...
274+
275+
@overload
276+
@classmethod
277+
def ones(
278+
cls: Type[DataClass[P]],
279+
shape: Union[Shape, Sizes],
280+
order: Order = "C",
281+
**kwargs: Any,
282+
) -> xr.DataArray:
283+
...
284+
285+
@classmethod
286+
def ones(
287+
cls: Any,
288+
shape: Union[Shape, Sizes],
289+
order: Order = "C",
290+
**kwargs: Any,
291+
) -> Any:
223292
"""Create a DataArray object filled with ones.
224293
225294
Args:
@@ -241,6 +310,7 @@ def ones(
241310
data = np.ones(shape, order=order)
242311
return asdataarray(cls(**{name: data}, **kwargs))
243312

313+
@overload
244314
@classmethod
245315
def full(
246316
cls: Type[DataArrayClass[P, TDataArray]],
@@ -249,6 +319,27 @@ def full(
249319
order: Order = "C",
250320
**kwargs: Any,
251321
) -> TDataArray:
322+
...
323+
324+
@overload
325+
@classmethod
326+
def full(
327+
cls: Type[DataClass[P]],
328+
shape: Union[Shape, Sizes],
329+
fill_value: Any,
330+
order: Order = "C",
331+
**kwargs: Any,
332+
) -> xr.DataArray:
333+
...
334+
335+
@classmethod
336+
def full(
337+
cls: Any,
338+
shape: Union[Shape, Sizes],
339+
fill_value: Any,
340+
order: Order = "C",
341+
**kwargs: Any,
342+
) -> Any:
252343
"""Create a DataArray object filled with given value.
253344
254345
Args:

0 commit comments

Comments
 (0)