|
3 | 3 |
|
4 | 4 | # standard library
|
5 | 5 | from dataclasses import Field
|
6 |
| -from functools import wraps |
| 6 | +from functools import partial, wraps |
7 | 7 | from types import MethodType
|
8 | 8 | from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
|
9 | 9 |
|
@@ -166,6 +166,52 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
|
166 | 166 |
|
167 | 167 | return MethodType(new, cls)
|
168 | 168 |
|
| 169 | + @overload |
| 170 | + @classmethod |
| 171 | + def shaped( |
| 172 | + cls: Type[DataArrayClass[P, TDataArray]], |
| 173 | + func: Callable[[Shape], np.ndarray], |
| 174 | + shape: Union[Shape, Sizes], |
| 175 | + **kwargs: Any, |
| 176 | + ) -> TDataArray: |
| 177 | + ... |
| 178 | + |
| 179 | + @overload |
| 180 | + @classmethod |
| 181 | + def shaped( |
| 182 | + cls: Type[DataClass[P]], |
| 183 | + func: Callable[[Shape], np.ndarray], |
| 184 | + shape: Union[Shape, Sizes], |
| 185 | + **kwargs: Any, |
| 186 | + ) -> xr.DataArray: |
| 187 | + ... |
| 188 | + |
| 189 | + @classmethod |
| 190 | + def shaped( |
| 191 | + cls: Any, |
| 192 | + func: Callable[[Shape], np.ndarray], |
| 193 | + shape: Union[Shape, Sizes], |
| 194 | + **kwargs: Any, |
| 195 | + ) -> Any: |
| 196 | + """Create a DataArray object from a shaped function. |
| 197 | +
|
| 198 | + Args: |
| 199 | + func: Function to create an array with given shape. |
| 200 | + shape: Shape or sizes of the new DataArray object. |
| 201 | + kwargs: Args of the DataArray class except for data. |
| 202 | +
|
| 203 | + Returns: |
| 204 | + DataArray object created from the shaped function. |
| 205 | +
|
| 206 | + """ |
| 207 | + model = DataModel.from_dataclass(cls) |
| 208 | + name, item = next(iter(model.data.items())) |
| 209 | + |
| 210 | + if isinstance(shape, dict): |
| 211 | + shape = tuple(shape[dim] for dim in item.type["dims"]) |
| 212 | + |
| 213 | + return asdataarray(cls(**{name: func(shape)}, **kwargs)) |
| 214 | + |
169 | 215 | @overload
|
170 | 216 | @classmethod
|
171 | 217 | def empty(
|
@@ -205,14 +251,8 @@ def empty(
|
205 | 251 | DataArray object without initializing data.
|
206 | 252 |
|
207 | 253 | """
|
208 |
| - model = DataModel.from_dataclass(cls) |
209 |
| - name, item = next(iter(model.data.items())) |
210 |
| - |
211 |
| - if isinstance(shape, dict): |
212 |
| - shape = tuple(shape[dim] for dim in item.type["dims"]) |
213 |
| - |
214 |
| - data = np.empty(shape, order=order) |
215 |
| - return asdataarray(cls(**{name: data}, **kwargs)) |
| 254 | + func = partial(np.empty, order=order) |
| 255 | + return cls.shaped(func, shape, **kwargs) |
216 | 256 |
|
217 | 257 | @overload
|
218 | 258 | @classmethod
|
@@ -253,14 +293,8 @@ def zeros(
|
253 | 293 | DataArray object filled with zeros.
|
254 | 294 |
|
255 | 295 | """
|
256 |
| - model = DataModel.from_dataclass(cls) |
257 |
| - name, item = next(iter(model.data.items())) |
258 |
| - |
259 |
| - if isinstance(shape, dict): |
260 |
| - shape = tuple(shape[dim] for dim in item.type["dims"]) |
261 |
| - |
262 |
| - data = np.zeros(shape, order=order) |
263 |
| - return asdataarray(cls(**{name: data}, **kwargs)) |
| 296 | + func = partial(np.zeros, order=order) |
| 297 | + return cls.shaped(func, shape, **kwargs) |
264 | 298 |
|
265 | 299 | @overload
|
266 | 300 | @classmethod
|
@@ -301,14 +335,8 @@ def ones(
|
301 | 335 | DataArray object filled with ones.
|
302 | 336 |
|
303 | 337 | """
|
304 |
| - model = DataModel.from_dataclass(cls) |
305 |
| - name, item = next(iter(model.data.items())) |
306 |
| - |
307 |
| - if isinstance(shape, dict): |
308 |
| - shape = tuple(shape[dim] for dim in item.type["dims"]) |
309 |
| - |
310 |
| - data = np.ones(shape, order=order) |
311 |
| - return asdataarray(cls(**{name: data}, **kwargs)) |
| 338 | + func = partial(np.ones, order=order) |
| 339 | + return cls.shaped(func, shape, **kwargs) |
312 | 340 |
|
313 | 341 | @overload
|
314 | 342 | @classmethod
|
@@ -353,11 +381,5 @@ def full(
|
353 | 381 | DataArray object filled with given value.
|
354 | 382 |
|
355 | 383 | """
|
356 |
| - model = DataModel.from_dataclass(cls) |
357 |
| - name, item = next(iter(model.data.items())) |
358 |
| - |
359 |
| - if isinstance(shape, dict): |
360 |
| - shape = tuple(shape[dim] for dim in item.type["dims"]) |
361 |
| - |
362 |
| - data = np.full(shape, fill_value, order=order) |
363 |
| - return asdataarray(cls(**{name: data}, **kwargs)) |
| 384 | + func = partial(np.full, fill_value=fill_value, order=order) |
| 385 | + return cls.shaped(func, shape, **kwargs) |
0 commit comments