Skip to content

Commit 58184bd

Browse files
committed
#123 Add shaped to AsDataset
1 parent 71888c3 commit 58184bd

File tree

1 file changed

+57
-34
lines changed

1 file changed

+57
-34
lines changed

xarray_dataclasses/dataset.py

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# standard library
55
from dataclasses import Field
6-
from functools import wraps
6+
from functools import partial, wraps
77
from types import MethodType
88
from typing import Any, Callable, Dict, Optional, Type, TypeVar, overload
99

@@ -18,7 +18,7 @@
1818
# submodules
1919
from .datamodel import DataModel
2020
from .dataoptions import DataOptions
21-
from .typing import DataType, Order, Sizes
21+
from .typing import DataType, Order, Shape, Sizes
2222

2323

2424
# constants
@@ -164,6 +164,53 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
164164

165165
return MethodType(new, cls)
166166

167+
@overload
168+
@classmethod
169+
def shaped(
170+
cls: Type[DatasetClass[P, TDataset]],
171+
func: Callable[[Shape], np.ndarray],
172+
sizes: Sizes,
173+
**kwargs: Any,
174+
) -> TDataset:
175+
...
176+
177+
@overload
178+
@classmethod
179+
def shaped(
180+
cls: Type[DataClass[P]],
181+
func: Callable[[Shape], np.ndarray],
182+
sizes: Sizes,
183+
**kwargs: Any,
184+
) -> xr.Dataset:
185+
...
186+
187+
@classmethod
188+
def shaped(
189+
cls: Any,
190+
func: Callable[[Shape], np.ndarray],
191+
sizes: Sizes,
192+
**kwargs: Any,
193+
) -> Any:
194+
"""Create a Dataset object from a shaped function.
195+
196+
Args:
197+
func: Function to create an array with given shape.
198+
sizes: Sizes of the new Dataset object.
199+
kwargs: Args of the Dataset class except for data vars.
200+
201+
Returns:
202+
Dataset object created from the shaped function.
203+
204+
"""
205+
model = DataModel.from_dataclass(cls)
206+
data_vars: Dict[str, Any] = {}
207+
208+
for name, item in model.data.items():
209+
shape = tuple(sizes[dim] for dim in item.type["dims"])
210+
data_vars[name] = func(shape)
211+
212+
return asdataset(cls(**data_vars, **kwargs))
213+
167214
@overload
168215
@classmethod
169216
def empty(
@@ -203,14 +250,8 @@ def empty(
203250
Dataset object without initializing data vars.
204251
205252
"""
206-
model = DataModel.from_dataclass(cls)
207-
data_vars: Dict[str, Any] = {}
208-
209-
for name, item in model.data.items():
210-
shape = tuple(sizes[dim] for dim in item.type["dims"])
211-
data_vars[name] = np.empty(shape, order=order)
212-
213-
return asdataset(cls(**data_vars, **kwargs))
253+
func = partial(np.empty, order=order)
254+
return cls.shaped(func, sizes, **kwargs)
214255

215256
@overload
216257
@classmethod
@@ -251,14 +292,8 @@ def zeros(
251292
Dataset object whose data vars are filled with zeros.
252293
253294
"""
254-
model = DataModel.from_dataclass(cls)
255-
data_vars: Dict[str, Any] = {}
256-
257-
for name, item in model.data.items():
258-
shape = tuple(sizes[dim] for dim in item.type["dims"])
259-
data_vars[name] = np.zeros(shape, order=order)
260-
261-
return asdataset(cls(**data_vars, **kwargs))
295+
func = partial(np.zeros, order=order)
296+
return cls.shaped(func, sizes, **kwargs)
262297

263298
@overload
264299
@classmethod
@@ -299,14 +334,8 @@ def ones(
299334
Dataset object whose data vars are filled with ones.
300335
301336
"""
302-
model = DataModel.from_dataclass(cls)
303-
data_vars: Dict[str, Any] = {}
304-
305-
for name, item in model.data.items():
306-
shape = tuple(sizes[dim] for dim in item.type["dims"])
307-
data_vars[name] = np.ones(shape, order=order)
308-
309-
return asdataset(cls(**data_vars, **kwargs))
337+
func = partial(np.ones, order=order)
338+
return cls.shaped(func, sizes, **kwargs)
310339

311340
@overload
312341
@classmethod
@@ -351,11 +380,5 @@ def full(
351380
Dataset object whose data vars are filled with given value.
352381
353382
"""
354-
model = DataModel.from_dataclass(cls)
355-
data_vars: Dict[str, Any] = {}
356-
357-
for name, item in model.data.items():
358-
shape = tuple(sizes[dim] for dim in item.type["dims"])
359-
data_vars[name] = np.full(shape, fill_value, order=order)
360-
361-
return asdataset(cls(**data_vars, **kwargs))
383+
func = partial(np.full, fill_value=fill_value, order=order)
384+
return cls.shaped(func, sizes, **kwargs)

0 commit comments

Comments
 (0)