Skip to content

Commit 96d058d

Browse files
committed
#121 Use overload in AsDataArray
1 parent 507e2c3 commit 96d058d

File tree

1 file changed

+122
-34
lines changed

1 file changed

+122
-34
lines changed

xarray_dataclasses/dataarray.py

Lines changed: 122 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,12 @@ def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
4949
__dataoptions__: DataOptions[TDataArray]
5050

5151

52-
# custom classproperty
53-
class classproperty:
54-
"""Class property only for AsDataArray.new().
55-
56-
As a classmethod and a property can be chained together since Python 3.9,
57-
this class will be removed when the support for Python 3.7 and 3.8 ends.
58-
59-
"""
60-
61-
def __init__(self, func: Callable[..., Any]) -> None:
62-
self.__func__ = func
63-
64-
def __get__(
65-
self,
66-
obj: Any,
67-
cls: Type[DataArrayClass[P, TDataArray]],
68-
) -> Callable[P, TDataArray]:
69-
return self.__func__(cls)
70-
71-
72-
# runtime functions and classes
52+
# runtime functions
7353
@overload
7454
def asdataarray(
7555
dataclass: DataArrayClass[Any, TDataArray],
7656
reference: Optional[DataType] = None,
77-
dataoptions: Any = DEFAULT_OPTIONS,
57+
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
7858
) -> TDataArray:
7959
...
8060

@@ -90,8 +70,8 @@ def asdataarray(
9070

9171
def asdataarray(
9272
dataclass: Any,
93-
reference: Any = None,
94-
dataoptions: Any = DEFAULT_OPTIONS,
73+
reference: Optional[DataType] = None,
74+
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
9575
) -> Any:
9676
"""Create a DataArray object from a dataclass object.
9777
@@ -137,36 +117,82 @@ def asdataarray(
137117
return dataarray
138118

139119

120+
# runtime classes
121+
class classproperty:
122+
"""Class property only for AsDataArray.new().
123+
124+
As a classmethod and a property can be chained together since Python 3.9,
125+
this class will be removed when the support for Python 3.7 and 3.8 ends.
126+
127+
"""
128+
129+
def __init__(self, func: Any) -> None:
130+
self.__func__ = func
131+
132+
@overload
133+
def __get__(
134+
self,
135+
obj: Any,
136+
cls: Type[DataArrayClass[P, TDataArray]],
137+
) -> Callable[P, TDataArray]:
138+
...
139+
140+
@overload
141+
def __get__(
142+
self,
143+
obj: Any,
144+
cls: Type[DataClass[P]],
145+
) -> Callable[P, xr.DataArray]:
146+
...
147+
148+
def __get__(self, obj: Any, cls: Any) -> Any:
149+
return self.__func__(cls)
150+
151+
140152
class AsDataArray:
141153
"""Mix-in class that provides shorthand methods."""
142154

143-
__dataoptions__ = DEFAULT_OPTIONS
144-
145155
@classproperty
146-
def new(cls: Type[DataArrayClass[P, TDataArray]]) -> Callable[P, TDataArray]:
156+
def new(cls: Any) -> Any:
147157
"""Create a DataArray object from dataclass parameters."""
148158

149-
init = copy(cls.__init__) # type: ignore
150-
init.__doc__ = cls.__init__.__doc__ # type: ignore
159+
init = copy(cls.__init__)
160+
init.__doc__ = cls.__init__.__doc__
151161
init.__annotations__["return"] = TDataArray
152162

153163
@wraps(init)
154-
def new(
155-
cls: Type[DataArrayClass[P, TDataArray]],
156-
*args: P.args,
157-
**kwargs: P.kwargs,
158-
) -> TDataArray:
164+
def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
159165
return asdataarray(cls(*args, **kwargs))
160166

161167
return MethodType(new, cls)
162168

169+
@overload
163170
@classmethod
164171
def empty(
165172
cls: Type[DataArrayClass[P, TDataArray]],
166173
shape: Union[Shape, Sizes],
167174
order: Order = "C",
168175
**kwargs: Any,
169176
) -> TDataArray:
177+
...
178+
179+
@overload
180+
@classmethod
181+
def empty(
182+
cls: Type[DataClass[P]],
183+
shape: Union[Shape, Sizes],
184+
order: Order = "C",
185+
**kwargs: Any,
186+
) -> xr.DataArray:
187+
...
188+
189+
@classmethod
190+
def empty(
191+
cls: Any,
192+
shape: Union[Shape, Sizes],
193+
order: Order = "C",
194+
**kwargs: Any,
195+
) -> Any:
170196
"""Create a DataArray object without initializing data.
171197
172198
Args:
@@ -188,13 +214,33 @@ def empty(
188214
data = np.empty(shape, order=order)
189215
return asdataarray(cls(**{name: data}, **kwargs))
190216

217+
@overload
191218
@classmethod
192219
def zeros(
193220
cls: Type[DataArrayClass[P, TDataArray]],
194221
shape: Union[Shape, Sizes],
195222
order: Order = "C",
196223
**kwargs: Any,
197224
) -> TDataArray:
225+
...
226+
227+
@overload
228+
@classmethod
229+
def zeros(
230+
cls: Type[DataClass[P]],
231+
shape: Union[Shape, Sizes],
232+
order: Order = "C",
233+
**kwargs: Any,
234+
) -> xr.DataArray:
235+
...
236+
237+
@classmethod
238+
def zeros(
239+
cls: Any,
240+
shape: Union[Shape, Sizes],
241+
order: Order = "C",
242+
**kwargs: Any,
243+
) -> Any:
198244
"""Create a DataArray object filled with zeros.
199245
200246
Args:
@@ -216,13 +262,33 @@ def zeros(
216262
data = np.zeros(shape, order=order)
217263
return asdataarray(cls(**{name: data}, **kwargs))
218264

265+
@overload
219266
@classmethod
220267
def ones(
221268
cls: Type[DataArrayClass[P, TDataArray]],
222269
shape: Union[Shape, Sizes],
223270
order: Order = "C",
224271
**kwargs: Any,
225272
) -> TDataArray:
273+
...
274+
275+
@overload
276+
@classmethod
277+
def ones(
278+
cls: Type[DataClass[P]],
279+
shape: Union[Shape, Sizes],
280+
order: Order = "C",
281+
**kwargs: Any,
282+
) -> xr.DataArray:
283+
...
284+
285+
@classmethod
286+
def ones(
287+
cls: Any,
288+
shape: Union[Shape, Sizes],
289+
order: Order = "C",
290+
**kwargs: Any,
291+
) -> Any:
226292
"""Create a DataArray object filled with ones.
227293
228294
Args:
@@ -244,6 +310,7 @@ def ones(
244310
data = np.ones(shape, order=order)
245311
return asdataarray(cls(**{name: data}, **kwargs))
246312

313+
@overload
247314
@classmethod
248315
def full(
249316
cls: Type[DataArrayClass[P, TDataArray]],
@@ -252,6 +319,27 @@ def full(
252319
order: Order = "C",
253320
**kwargs: Any,
254321
) -> TDataArray:
322+
...
323+
324+
@overload
325+
@classmethod
326+
def full(
327+
cls: Type[DataClass[P]],
328+
shape: Union[Shape, Sizes],
329+
fill_value: Any,
330+
order: Order = "C",
331+
**kwargs: Any,
332+
) -> xr.DataArray:
333+
...
334+
335+
@classmethod
336+
def full(
337+
cls: Any,
338+
shape: Union[Shape, Sizes],
339+
fill_value: Any,
340+
order: Order = "C",
341+
**kwargs: Any,
342+
) -> Any:
255343
"""Create a DataArray object filled with given value.
256344
257345
Args:

0 commit comments

Comments
 (0)