Skip to content

Commit 71888c3

Browse files
committed
#123 Add shaped to AsDataArray
1 parent 483fdaa commit 71888c3

File tree

1 file changed

+55
-33
lines changed

1 file changed

+55
-33
lines changed

xarray_dataclasses/dataarray.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# standard library
55
from dataclasses import Field
6-
from functools import wraps
6+
from functools import partial, wraps
77
from types import MethodType
88
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
99

@@ -166,6 +166,52 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
166166

167167
return MethodType(new, cls)
168168

169+
@overload
170+
@classmethod
171+
def shaped(
172+
cls: Type[DataArrayClass[P, TDataArray]],
173+
func: Callable[[Shape], np.ndarray],
174+
shape: Union[Shape, Sizes],
175+
**kwargs: Any,
176+
) -> TDataArray:
177+
...
178+
179+
@overload
180+
@classmethod
181+
def shaped(
182+
cls: Type[DataClass[P]],
183+
func: Callable[[Shape], np.ndarray],
184+
shape: Union[Shape, Sizes],
185+
**kwargs: Any,
186+
) -> xr.DataArray:
187+
...
188+
189+
@classmethod
190+
def shaped(
191+
cls: Any,
192+
func: Callable[[Shape], np.ndarray],
193+
shape: Union[Shape, Sizes],
194+
**kwargs: Any,
195+
) -> Any:
196+
"""Create a DataArray object from a shaped function.
197+
198+
Args:
199+
func: Function to create an array with given shape.
200+
shape: Shape or sizes of the new DataArray object.
201+
kwargs: Args of the DataArray class except for data.
202+
203+
Returns:
204+
DataArray object created from the shaped function.
205+
206+
"""
207+
model = DataModel.from_dataclass(cls)
208+
name, item = next(iter(model.data.items()))
209+
210+
if isinstance(shape, dict):
211+
shape = tuple(shape[dim] for dim in item.type["dims"])
212+
213+
return asdataarray(cls(**{name: func(shape)}, **kwargs))
214+
169215
@overload
170216
@classmethod
171217
def empty(
@@ -205,14 +251,8 @@ def empty(
205251
DataArray object without initializing data.
206252
207253
"""
208-
model = DataModel.from_dataclass(cls)
209-
name, item = next(iter(model.data.items()))
210-
211-
if isinstance(shape, dict):
212-
shape = tuple(shape[dim] for dim in item.type["dims"])
213-
214-
data = np.empty(shape, order=order)
215-
return asdataarray(cls(**{name: data}, **kwargs))
254+
func = partial(np.empty, order=order)
255+
return cls.shaped(func, shape, **kwargs)
216256

217257
@overload
218258
@classmethod
@@ -253,14 +293,8 @@ def zeros(
253293
DataArray object filled with zeros.
254294
255295
"""
256-
model = DataModel.from_dataclass(cls)
257-
name, item = next(iter(model.data.items()))
258-
259-
if isinstance(shape, dict):
260-
shape = tuple(shape[dim] for dim in item.type["dims"])
261-
262-
data = np.zeros(shape, order=order)
263-
return asdataarray(cls(**{name: data}, **kwargs))
296+
func = partial(np.zeros, order=order)
297+
return cls.shaped(func, shape, **kwargs)
264298

265299
@overload
266300
@classmethod
@@ -301,14 +335,8 @@ def ones(
301335
DataArray object filled with ones.
302336
303337
"""
304-
model = DataModel.from_dataclass(cls)
305-
name, item = next(iter(model.data.items()))
306-
307-
if isinstance(shape, dict):
308-
shape = tuple(shape[dim] for dim in item.type["dims"])
309-
310-
data = np.ones(shape, order=order)
311-
return asdataarray(cls(**{name: data}, **kwargs))
338+
func = partial(np.ones, order=order)
339+
return cls.shaped(func, shape, **kwargs)
312340

313341
@overload
314342
@classmethod
@@ -353,11 +381,5 @@ def full(
353381
DataArray object filled with given value.
354382
355383
"""
356-
model = DataModel.from_dataclass(cls)
357-
name, item = next(iter(model.data.items()))
358-
359-
if isinstance(shape, dict):
360-
shape = tuple(shape[dim] for dim in item.type["dims"])
361-
362-
data = np.full(shape, fill_value, order=order)
363-
return asdataarray(cls(**{name: data}, **kwargs))
384+
func = partial(np.full, fill_value=fill_value, order=order)
385+
return cls.shaped(func, shape, **kwargs)

0 commit comments

Comments
 (0)