|
3 | 3 |
|
4 | 4 | # standard library
|
5 | 5 | from dataclasses import Field
|
6 |
| -from functools import wraps |
| 6 | +from functools import partial, wraps |
7 | 7 | from types import MethodType
|
8 | 8 | from typing import Any, Callable, Dict, Optional, Type, TypeVar, overload
|
9 | 9 |
|
|
18 | 18 | # submodules
|
19 | 19 | from .datamodel import DataModel
|
20 | 20 | from .dataoptions import DataOptions
|
21 |
| -from .typing import DataType, Order, Sizes |
| 21 | +from .typing import DataType, Order, Shape, Sizes |
22 | 22 |
|
23 | 23 |
|
24 | 24 | # constants
|
@@ -164,6 +164,53 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
|
164 | 164 |
|
165 | 165 | return MethodType(new, cls)
|
166 | 166 |
|
| 167 | + @overload |
| 168 | + @classmethod |
| 169 | + def shaped( |
| 170 | + cls: Type[DatasetClass[P, TDataset]], |
| 171 | + func: Callable[[Shape], np.ndarray], |
| 172 | + sizes: Sizes, |
| 173 | + **kwargs: Any, |
| 174 | + ) -> TDataset: |
| 175 | + ... |
| 176 | + |
| 177 | + @overload |
| 178 | + @classmethod |
| 179 | + def shaped( |
| 180 | + cls: Type[DataClass[P]], |
| 181 | + func: Callable[[Shape], np.ndarray], |
| 182 | + sizes: Sizes, |
| 183 | + **kwargs: Any, |
| 184 | + ) -> xr.Dataset: |
| 185 | + ... |
| 186 | + |
| 187 | + @classmethod |
| 188 | + def shaped( |
| 189 | + cls: Any, |
| 190 | + func: Callable[[Shape], np.ndarray], |
| 191 | + sizes: Sizes, |
| 192 | + **kwargs: Any, |
| 193 | + ) -> Any: |
| 194 | + """Create a Dataset object from a shaped function. |
| 195 | +
|
| 196 | + Args: |
| 197 | + func: Function to create an array with given shape. |
| 198 | + sizes: Sizes of the new Dataset object. |
| 199 | + kwargs: Args of the Dataset class except for data vars. |
| 200 | +
|
| 201 | + Returns: |
| 202 | + Dataset object created from the shaped function. |
| 203 | +
|
| 204 | + """ |
| 205 | + model = DataModel.from_dataclass(cls) |
| 206 | + data_vars: Dict[str, Any] = {} |
| 207 | + |
| 208 | + for name, item in model.data.items(): |
| 209 | + shape = tuple(sizes[dim] for dim in item.type["dims"]) |
| 210 | + data_vars[name] = func(shape) |
| 211 | + |
| 212 | + return asdataset(cls(**data_vars, **kwargs)) |
| 213 | + |
167 | 214 | @overload
|
168 | 215 | @classmethod
|
169 | 216 | def empty(
|
@@ -203,14 +250,8 @@ def empty(
|
203 | 250 | Dataset object without initializing data vars.
|
204 | 251 |
|
205 | 252 | """
|
206 |
| - model = DataModel.from_dataclass(cls) |
207 |
| - data_vars: Dict[str, Any] = {} |
208 |
| - |
209 |
| - for name, item in model.data.items(): |
210 |
| - shape = tuple(sizes[dim] for dim in item.type["dims"]) |
211 |
| - data_vars[name] = np.empty(shape, order=order) |
212 |
| - |
213 |
| - return asdataset(cls(**data_vars, **kwargs)) |
| 253 | + func = partial(np.empty, order=order) |
| 254 | + return cls.shaped(func, sizes, **kwargs) |
214 | 255 |
|
215 | 256 | @overload
|
216 | 257 | @classmethod
|
@@ -251,14 +292,8 @@ def zeros(
|
251 | 292 | Dataset object whose data vars are filled with zeros.
|
252 | 293 |
|
253 | 294 | """
|
254 |
| - model = DataModel.from_dataclass(cls) |
255 |
| - data_vars: Dict[str, Any] = {} |
256 |
| - |
257 |
| - for name, item in model.data.items(): |
258 |
| - shape = tuple(sizes[dim] for dim in item.type["dims"]) |
259 |
| - data_vars[name] = np.zeros(shape, order=order) |
260 |
| - |
261 |
| - return asdataset(cls(**data_vars, **kwargs)) |
| 295 | + func = partial(np.zeros, order=order) |
| 296 | + return cls.shaped(func, sizes, **kwargs) |
262 | 297 |
|
263 | 298 | @overload
|
264 | 299 | @classmethod
|
@@ -299,14 +334,8 @@ def ones(
|
299 | 334 | Dataset object whose data vars are filled with ones.
|
300 | 335 |
|
301 | 336 | """
|
302 |
| - model = DataModel.from_dataclass(cls) |
303 |
| - data_vars: Dict[str, Any] = {} |
304 |
| - |
305 |
| - for name, item in model.data.items(): |
306 |
| - shape = tuple(sizes[dim] for dim in item.type["dims"]) |
307 |
| - data_vars[name] = np.ones(shape, order=order) |
308 |
| - |
309 |
| - return asdataset(cls(**data_vars, **kwargs)) |
| 337 | + func = partial(np.ones, order=order) |
| 338 | + return cls.shaped(func, sizes, **kwargs) |
310 | 339 |
|
311 | 340 | @overload
|
312 | 341 | @classmethod
|
@@ -351,11 +380,5 @@ def full(
|
351 | 380 | Dataset object whose data vars are filled with given value.
|
352 | 381 |
|
353 | 382 | """
|
354 |
| - model = DataModel.from_dataclass(cls) |
355 |
| - data_vars: Dict[str, Any] = {} |
356 |
| - |
357 |
| - for name, item in model.data.items(): |
358 |
| - shape = tuple(sizes[dim] for dim in item.type["dims"]) |
359 |
| - data_vars[name] = np.full(shape, fill_value, order=order) |
360 |
| - |
361 |
| - return asdataset(cls(**data_vars, **kwargs)) |
| 383 | + func = partial(np.full, fill_value=fill_value, order=order) |
| 384 | + return cls.shaped(func, sizes, **kwargs) |
0 commit comments