Skip to content

Commit 68c2591

Browse files
authored
Merge pull request #113 from astropenguin/astropenguin/issue112
Add data options
2 parents 8b68233 + a4423d2 commit 68c2591

File tree

9 files changed

+107
-49
lines changed

9 files changed

+107
-49
lines changed

README.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,15 @@ class Image(AsDataArray):
291291
y: Coordof[YAxis] = 0
292292
```
293293

294-
### Custom DataArray and Dataset factories
294+
### Options for DataArray and Dataset creation
295+
296+
For customization, users can add a special class attribute, `__dataoptions__`, to a DataArray or Dataset class.
297+
A custom factory for DataArray or Dataset creation is only supported in the current implementation.
295298

296-
For customization, users can use a function or a class to create an initial DataArray or Dataset object by specifying a special class attribute, `__dataarray_factory__` or `__dataset_factory__`, respectively.
297299

298300
```python
299301
import xarray as xr
302+
from xarray_dataclasses import DataOptions
300303

301304

302305
class Custom(xr.DataArray):
@@ -308,19 +311,23 @@ class Custom(xr.DataArray):
308311
print("Custom method!")
309312

310313

314+
dataoptions = DataOptions(Custom)
315+
316+
311317
@dataclass
312318
class Image(AsDataArray):
313319
"""Specs for a monochromatic image."""
314320

321+
__dataoptions__ = dataoptions
322+
315323
data: Data[tuple[X, Y], float]
316324
x: Coord[X, int] = 0
317325
y: Coord[Y, int] = 0
318-
__dataarray_factory__ = Custom
319326

320327

321328
image = Image.ones([3, 3])
322-
isinstance(image, Custom) # True
323-
image.custom_method() # Custom method!
329+
isinstance(image, Custom) # True
330+
image.custom_method() # Custom method!
324331
```
325332

326333
### DataArray and Dataset creation without shorthands

tests/test_dataarray.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# standard library
22
from dataclasses import dataclass
3-
from typing import Any, Tuple
3+
from typing import Tuple
44

55

66
# third-party packages
@@ -11,6 +11,7 @@
1111

1212
# submodules
1313
from xarray_dataclasses.dataarray import AsDataArray
14+
from xarray_dataclasses.dataoptions import DataOptions
1415
from xarray_dataclasses.typing import Attr, Coord, Data, Name
1516

1617

@@ -31,19 +32,21 @@ class Custom(xr.DataArray):
3132
__slots__ = ()
3233

3334

35+
dataoptions = DataOptions(Custom)
36+
37+
3438
@dataclass
3539
class Image(AsDataArray):
3640
"""Specs for a monochromatic image."""
3741

42+
__dataoptions__ = dataoptions
43+
3844
data: Data[Tuple[X, Y], float]
3945
x: Coord[X, int] = 0
4046
y: Coord[Y, int] = 0
4147
units: Attr[str] = "cd / m^2"
4248
name: Name[str] = "luminance"
4349

44-
def __dataarray_factory__(self, data: Any = None) -> Custom:
45-
return Custom(data)
46-
4750

4851
# test datasets
4952
created = Image.ones(SHAPE)

tests/test_dataset.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# standard library
22
from dataclasses import dataclass
3-
from typing import Any, Tuple
3+
from typing import Tuple
44

55

66
# third-party packages
@@ -12,6 +12,7 @@
1212
# submodules
1313
from xarray_dataclasses.dataarray import AsDataArray
1414
from xarray_dataclasses.dataset import AsDataset
15+
from xarray_dataclasses.dataoptions import DataOptions
1516
from xarray_dataclasses.typing import Attr, Coord, Data
1617

1718
# constants
@@ -29,6 +30,9 @@ class Custom(xr.Dataset):
2930
__slots__ = ()
3031

3132

33+
dataoptions = DataOptions(Custom)
34+
35+
3236
@dataclass
3337
class Image(AsDataArray):
3438
"""Specs for a monochromatic image."""
@@ -40,16 +44,15 @@ class Image(AsDataArray):
4044
class ColorImage(AsDataset):
4145
"""Specs for a color image."""
4246

47+
__dataoptions__ = dataoptions
48+
4349
red: Data[Tuple[X, Y], float]
4450
green: Data[Tuple[X, Y], float]
4551
blue: Data[Tuple[X, Y], float]
4652
x: Coord[X, int] = 0
4753
y: Coord[Y, int] = 0
4854
units: Attr[str] = "cd / m^2"
4955

50-
def __dataset_factory__(self, data_vars: Any = None) -> Custom:
51-
return Custom(data_vars)
52-
5356

5457
# test datasets
5558
created = ColorImage.new(

xarray_dataclasses/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _make_field_generic():
2020
from . import dataset
2121
from . import deprecated
2222
from . import datamodel
23+
from . import dataoptions
2324
from . import typing
2425

2526

@@ -28,6 +29,7 @@ def _make_field_generic():
2829
from .dataset import *
2930
from .deprecated import *
3031
from .datamodel import *
32+
from .dataoptions import *
3133
from .typing import *
3234

3335

xarray_dataclasses/dataarray.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import Field
66
from functools import wraps
77
from types import MethodType
8-
from typing import Any, Callable, Dict, Type, TypeVar, Union, overload
8+
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
99

1010

1111
# dependencies
@@ -16,8 +16,13 @@
1616

1717

1818
# submodules
19-
from .datamodel import DataModel, Reference
20-
from .typing import Order, Shape, Sizes
19+
from .datamodel import DataModel
20+
from .dataoptions import DataOptions
21+
from .typing import DataType, Order, Shape, Sizes
22+
23+
24+
# constants
25+
DEFAULT_OPTIONS = DataOptions(xr.DataArray)
2126

2227

2328
# type hints
@@ -38,7 +43,7 @@ class DataArrayClass(Protocol[P, TDataArray_]):
3843

3944
__init__: Callable[P, None]
4045
__dataclass_fields__: Dict[str, Field[Any]]
41-
__dataarray_factory__: Callable[..., TDataArray_]
46+
__dataoptions__: DataOptions[TDataArray_]
4247

4348

4449
# custom classproperty
@@ -65,44 +70,50 @@ def __get__(
6570
@overload
6671
def asdataarray(
6772
dataclass: DataArrayClass[Any, TDataArray],
68-
reference: Reference = None,
69-
dataarray_factory: Any = xr.DataArray,
73+
reference: Optional[DataType] = None,
74+
dataoptions: Any = DEFAULT_OPTIONS,
7075
) -> TDataArray:
7176
...
7277

7378

7479
@overload
7580
def asdataarray(
7681
dataclass: DataClass[Any],
77-
reference: Reference = None,
78-
dataarray_factory: Callable[..., TDataArray] = xr.DataArray,
82+
reference: Optional[DataType] = None,
83+
dataoptions: DataOptions[TDataArray] = DEFAULT_OPTIONS,
7984
) -> TDataArray:
8085
...
8186

8287

8388
def asdataarray(
8489
dataclass: Any,
8590
reference: Any = None,
86-
dataarray_factory: Any = xr.DataArray,
91+
dataoptions: Any = DEFAULT_OPTIONS,
8792
) -> Any:
8893
"""Create a DataArray object from a dataclass object.
8994
9095
Args:
9196
dataclass: Dataclass object that defines typed DataArray.
9297
reference: DataArray or Dataset object as a reference of shape.
93-
dataset_factory: Factory function of DataArray.
98+
dataoptions: Options for DataArray creation.
9499
95100
Returns:
96101
DataArray object created from the dataclass object.
97102
98103
"""
99104
try:
100-
dataarray_factory = dataclass.__dataarray_factory__
105+
# for backward compatibility (deprecated in v1.0.0)
106+
dataoptions = DataOptions(dataclass.__dataarray_factory__)
107+
except AttributeError:
108+
pass
109+
110+
try:
111+
dataoptions = dataclass.__dataoptions__
101112
except AttributeError:
102113
pass
103114

104115
model = DataModel.from_dataclass(dataclass)
105-
dataarray = dataarray_factory(model.data[0](reference))
116+
dataarray = dataoptions.factory(model.data[0](reference))
106117

107118
for coord in model.coord:
108119
dataarray.coords.update({coord.name: coord(dataarray)})
@@ -119,9 +130,7 @@ def asdataarray(
119130
class AsDataArray:
120131
"""Mix-in class that provides shorthand methods."""
121132

122-
def __dataarray_factory__(self, data: Any = None) -> xr.DataArray:
123-
"""Default DataArray factory (xarray.DataArray)."""
124-
return xr.DataArray(data)
133+
__dataoptions__ = DEFAULT_OPTIONS
125134

126135
@classproperty
127136
def new(cls: Type[DataArrayClass[P, TDataArray]]) -> Callable[P, TDataArray]:

xarray_dataclasses/datamodel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# standard library
55
from dataclasses import Field, dataclass, field, is_dataclass
6-
from typing import Any, List, Optional, Type, Union, cast
6+
from typing import Any, List, Optional, Type, cast
77

88

99
# dependencies
@@ -17,6 +17,7 @@
1717
from .typing import (
1818
ArrayLike,
1919
DataClass,
20+
DataType,
2021
Dims,
2122
Dtype,
2223
FieldType,
@@ -28,8 +29,7 @@
2829

2930

3031
# type hints
31-
DataType = TypedDict("DataType", dims=Dims, dtype=Dtype)
32-
Reference = Union[xr.DataArray, xr.Dataset, None]
32+
DimsDtype = TypedDict("DimsDtype", dims=Dims, dtype=Dtype)
3333

3434

3535
# field models
@@ -43,13 +43,13 @@ class Data:
4343
value: Any
4444
"""Value assigned to the field."""
4545

46-
type: DataType
46+
type: DimsDtype
4747
"""Type (dims and dtype) of the field."""
4848

4949
factory: Optional[Type[DataClass]] = None
5050
"""Factory dataclass to create a DataArray object."""
5151

52-
def __call__(self, reference: Reference = None) -> xr.DataArray:
52+
def __call__(self, reference: Optional[DataType] = None) -> xr.DataArray:
5353
"""Create a DataArray object from the value and a reference."""
5454
from .dataarray import asdataarray
5555

@@ -173,7 +173,7 @@ def typedarray(
173173
data: Any,
174174
dims: Dims,
175175
dtype: Dtype,
176-
reference: Reference = None,
176+
reference: Optional[DataType] = None,
177177
) -> xr.DataArray:
178178
"""Create a DataArray object with given dims and dtype.
179179

xarray_dataclasses/dataoptions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
__all__ = ["DataOptions"]
2+
3+
4+
# standard library
5+
from dataclasses import dataclass
6+
from typing import Callable, Generic, TypeVar
7+
8+
9+
# submodules
10+
from .typing import DataType
11+
12+
13+
# type hints
14+
TDataType = TypeVar("TDataType", bound=DataType)
15+
16+
17+
# dataclasses
18+
@dataclass(frozen=True)
19+
class DataOptions(Generic[TDataType]):
20+
"""Options for DataArray or Dataset creation."""
21+
22+
factory: Callable[..., TDataType]
23+
"""Factory function for DataArray or Dataset."""

0 commit comments

Comments
 (0)