Skip to content

Commit c5a47be

Browse files
authored
Merge pull request #124 from astropenguin:astropenguin/issue123
Add classmethod to create a shaped array
2 parents b5f4b04 + 58184bd commit c5a47be

File tree

3 files changed

+114
-67
lines changed

3 files changed

+114
-67
lines changed

pyrightconfig.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
{
22
"typeCheckingMode": "strict",
33
"reportImportCycles": "warning",
4+
"reportUnknownArgumentType": "warning",
45
"reportUnknownMemberType": "warning",
6+
"reportUnknownParameterType": "warning",
57
"reportUnknownVariableType": "warning"
68
}

xarray_dataclasses/dataarray.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# standard library
55
from dataclasses import Field
6-
from functools import wraps
6+
from functools import partial, wraps
77
from types import MethodType
88
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
99

@@ -166,6 +166,52 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
166166

167167
return MethodType(new, cls)
168168

169+
@overload
170+
@classmethod
171+
def shaped(
172+
cls: Type[DataArrayClass[P, TDataArray]],
173+
func: Callable[[Shape], np.ndarray],
174+
shape: Union[Shape, Sizes],
175+
**kwargs: Any,
176+
) -> TDataArray:
177+
...
178+
179+
@overload
180+
@classmethod
181+
def shaped(
182+
cls: Type[DataClass[P]],
183+
func: Callable[[Shape], np.ndarray],
184+
shape: Union[Shape, Sizes],
185+
**kwargs: Any,
186+
) -> xr.DataArray:
187+
...
188+
189+
@classmethod
190+
def shaped(
191+
cls: Any,
192+
func: Callable[[Shape], np.ndarray],
193+
shape: Union[Shape, Sizes],
194+
**kwargs: Any,
195+
) -> Any:
196+
"""Create a DataArray object from a shaped function.
197+
198+
Args:
199+
func: Function to create an array with given shape.
200+
shape: Shape or sizes of the new DataArray object.
201+
kwargs: Args of the DataArray class except for data.
202+
203+
Returns:
204+
DataArray object created from the shaped function.
205+
206+
"""
207+
model = DataModel.from_dataclass(cls)
208+
name, item = next(iter(model.data.items()))
209+
210+
if isinstance(shape, dict):
211+
shape = tuple(shape[dim] for dim in item.type["dims"])
212+
213+
return asdataarray(cls(**{name: func(shape)}, **kwargs))
214+
169215
@overload
170216
@classmethod
171217
def empty(
@@ -205,14 +251,8 @@ def empty(
205251
DataArray object without initializing data.
206252
207253
"""
208-
model = DataModel.from_dataclass(cls)
209-
name, item = next(iter(model.data.items()))
210-
211-
if isinstance(shape, dict):
212-
shape = tuple(shape[dim] for dim in item.type["dims"])
213-
214-
data = np.empty(shape, order=order)
215-
return asdataarray(cls(**{name: data}, **kwargs))
254+
func = partial(np.empty, order=order)
255+
return cls.shaped(func, shape, **kwargs)
216256

217257
@overload
218258
@classmethod
@@ -253,14 +293,8 @@ def zeros(
253293
DataArray object filled with zeros.
254294
255295
"""
256-
model = DataModel.from_dataclass(cls)
257-
name, item = next(iter(model.data.items()))
258-
259-
if isinstance(shape, dict):
260-
shape = tuple(shape[dim] for dim in item.type["dims"])
261-
262-
data = np.zeros(shape, order=order)
263-
return asdataarray(cls(**{name: data}, **kwargs))
296+
func = partial(np.zeros, order=order)
297+
return cls.shaped(func, shape, **kwargs)
264298

265299
@overload
266300
@classmethod
@@ -301,14 +335,8 @@ def ones(
301335
DataArray object filled with ones.
302336
303337
"""
304-
model = DataModel.from_dataclass(cls)
305-
name, item = next(iter(model.data.items()))
306-
307-
if isinstance(shape, dict):
308-
shape = tuple(shape[dim] for dim in item.type["dims"])
309-
310-
data = np.ones(shape, order=order)
311-
return asdataarray(cls(**{name: data}, **kwargs))
338+
func = partial(np.ones, order=order)
339+
return cls.shaped(func, shape, **kwargs)
312340

313341
@overload
314342
@classmethod
@@ -353,11 +381,5 @@ def full(
353381
DataArray object filled with given value.
354382
355383
"""
356-
model = DataModel.from_dataclass(cls)
357-
name, item = next(iter(model.data.items()))
358-
359-
if isinstance(shape, dict):
360-
shape = tuple(shape[dim] for dim in item.type["dims"])
361-
362-
data = np.full(shape, fill_value, order=order)
363-
return asdataarray(cls(**{name: data}, **kwargs))
384+
func = partial(np.full, fill_value=fill_value, order=order)
385+
return cls.shaped(func, shape, **kwargs)

xarray_dataclasses/dataset.py

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# standard library
55
from dataclasses import Field
6-
from functools import wraps
6+
from functools import partial, wraps
77
from types import MethodType
88
from typing import Any, Callable, Dict, Optional, Type, TypeVar, overload
99

@@ -18,7 +18,7 @@
1818
# submodules
1919
from .datamodel import DataModel
2020
from .dataoptions import DataOptions
21-
from .typing import DataType, Order, Sizes
21+
from .typing import DataType, Order, Shape, Sizes
2222

2323

2424
# constants
@@ -164,6 +164,53 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
164164

165165
return MethodType(new, cls)
166166

167+
@overload
168+
@classmethod
169+
def shaped(
170+
cls: Type[DatasetClass[P, TDataset]],
171+
func: Callable[[Shape], np.ndarray],
172+
sizes: Sizes,
173+
**kwargs: Any,
174+
) -> TDataset:
175+
...
176+
177+
@overload
178+
@classmethod
179+
def shaped(
180+
cls: Type[DataClass[P]],
181+
func: Callable[[Shape], np.ndarray],
182+
sizes: Sizes,
183+
**kwargs: Any,
184+
) -> xr.Dataset:
185+
...
186+
187+
@classmethod
188+
def shaped(
189+
cls: Any,
190+
func: Callable[[Shape], np.ndarray],
191+
sizes: Sizes,
192+
**kwargs: Any,
193+
) -> Any:
194+
"""Create a Dataset object from a shaped function.
195+
196+
Args:
197+
func: Function to create an array with given shape.
198+
sizes: Sizes of the new Dataset object.
199+
kwargs: Args of the Dataset class except for data vars.
200+
201+
Returns:
202+
Dataset object created from the shaped function.
203+
204+
"""
205+
model = DataModel.from_dataclass(cls)
206+
data_vars: Dict[str, Any] = {}
207+
208+
for name, item in model.data.items():
209+
shape = tuple(sizes[dim] for dim in item.type["dims"])
210+
data_vars[name] = func(shape)
211+
212+
return asdataset(cls(**data_vars, **kwargs))
213+
167214
@overload
168215
@classmethod
169216
def empty(
@@ -203,14 +250,8 @@ def empty(
203250
Dataset object without initializing data vars.
204251
205252
"""
206-
model = DataModel.from_dataclass(cls)
207-
data_vars: Dict[str, Any] = {}
208-
209-
for name, item in model.data.items():
210-
shape = tuple(sizes[dim] for dim in item.type["dims"])
211-
data_vars[name] = np.empty(shape, order=order)
212-
213-
return asdataset(cls(**data_vars, **kwargs))
253+
func = partial(np.empty, order=order)
254+
return cls.shaped(func, sizes, **kwargs)
214255

215256
@overload
216257
@classmethod
@@ -251,14 +292,8 @@ def zeros(
251292
Dataset object whose data vars are filled with zeros.
252293
253294
"""
254-
model = DataModel.from_dataclass(cls)
255-
data_vars: Dict[str, Any] = {}
256-
257-
for name, item in model.data.items():
258-
shape = tuple(sizes[dim] for dim in item.type["dims"])
259-
data_vars[name] = np.zeros(shape, order=order)
260-
261-
return asdataset(cls(**data_vars, **kwargs))
295+
func = partial(np.zeros, order=order)
296+
return cls.shaped(func, sizes, **kwargs)
262297

263298
@overload
264299
@classmethod
@@ -299,14 +334,8 @@ def ones(
299334
Dataset object whose data vars are filled with ones.
300335
301336
"""
302-
model = DataModel.from_dataclass(cls)
303-
data_vars: Dict[str, Any] = {}
304-
305-
for name, item in model.data.items():
306-
shape = tuple(sizes[dim] for dim in item.type["dims"])
307-
data_vars[name] = np.ones(shape, order=order)
308-
309-
return asdataset(cls(**data_vars, **kwargs))
337+
func = partial(np.ones, order=order)
338+
return cls.shaped(func, sizes, **kwargs)
310339

311340
@overload
312341
@classmethod
@@ -351,11 +380,5 @@ def full(
351380
Dataset object whose data vars are filled with given value.
352381
353382
"""
354-
model = DataModel.from_dataclass(cls)
355-
data_vars: Dict[str, Any] = {}
356-
357-
for name, item in model.data.items():
358-
shape = tuple(sizes[dim] for dim in item.type["dims"])
359-
data_vars[name] = np.full(shape, fill_value, order=order)
360-
361-
return asdataset(cls(**data_vars, **kwargs))
383+
func = partial(np.full, fill_value=fill_value, order=order)
384+
return cls.shaped(func, sizes, **kwargs)

0 commit comments

Comments
 (0)