Skip to content

Commit 9e0516b

Browse files
author
Joseph Hamman
committed
first pass at attrs and coords schemas
1 parent a6b8184 commit 9e0516b

File tree

6 files changed

+181
-16
lines changed

6 files changed

+181
-16
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ The basic usage is as follows:
3838
```python
3939
import numpy as np
4040
import xarray as xr
41-
from xarray_schema import DataArraySchema, DatasetSchema
41+
from xarray_schema import DataArraySchema, DatasetSchema, CoordsSchema
4242

4343
da = xr.DataArray(np.ones(4, dtype='i4'), dims=['x'], name='foo')
4444

@@ -64,7 +64,9 @@ from xarray_schema.components import (
6464
ShapeSchema,
6565
NameSchema,
6666
ChunksSchema,
67-
ArrayTypeSchema
67+
ArrayTypeSchema,
68+
AttrSchema,
69+
AttrsSchema
6870
)
6971

7072
# example constructions

tests/test_core.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from xarray_schema.base import SchemaError
77
from xarray_schema.components import (
88
ArrayTypeSchema,
9+
AttrSchema,
10+
AttrsSchema,
911
ChunksSchema,
1012
DimsSchema,
1113
DTypeSchema,
@@ -52,6 +54,18 @@ def ds():
5254
[(((2, 2), (10,)), ('x', 'y'), (4, 10))],
5355
{'x': 2, 'y': -1},
5456
),
57+
(
58+
AttrsSchema,
59+
{'foo': AttrSchema(value='bar')},
60+
[{'foo': 'bar'}],
61+
{'foo': {'type': None, 'value': 'bar'}},
62+
),
63+
(
64+
AttrsSchema,
65+
{'foo': AttrSchema(value=1)},
66+
[{'foo': 1}],
67+
{'foo': {'type': None, 'value': 1}},
68+
),
5569
],
5670
)
5771
def test_component_schema(component, schema_args, validate, json):
@@ -65,6 +79,21 @@ def test_component_schema(component, schema_args, validate, json):
6579
assert isinstance(schema.to_json(), str)
6680

6781

82+
@pytest.mark.parametrize(
    'attr_type, attr_value, candidate, expected_json',
    [
        (str, None, 'foo', {'type': str, 'value': None}),
        (None, 'foo', 'foo', {'type': None, 'value': 'foo'}),
        (str, 'foo', 'foo', {'type': str, 'value': 'foo'}),
    ],
)
def test_attr_schema(attr_type, attr_value, candidate, expected_json):
    '''AttrSchema accepts a matching attribute and reports its type/value as json.

    Each case builds a schema from a type and/or value constraint, validates a
    conforming attribute (which must not raise), and compares the ``json``
    property against the expected dict.
    '''
    attr_schema = AttrSchema(type=attr_type, value=attr_value)

    # A conforming attribute must pass validation without raising.
    attr_schema.validate(candidate)

    assert attr_schema.json == expected_json
    # TODO: the string-serialization check is disabled for now:
    # assert isinstance(schema.to_json(), str)
95+
96+
6897
@pytest.mark.parametrize(
6998
'component, schema_args, validate, match',
7099
[

xarray_schema/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
NameSchema,
99
ShapeSchema,
1010
)
11-
from .dataarray import DataArraySchema # noqa: F401
11+
from .dataarray import CoordsSchema, DataArraySchema # noqa: F401
1212
from .dataset import DatasetSchema # noqa: F401
1313

1414
try:

xarray_schema/components.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections.abc import Iterable
2-
from typing import Any, Dict, Hashable, Optional, Tuple, Union
2+
from typing import Any, Dict, Hashable, Mapping, Optional, Tuple, Union
33

44
import numpy as np
5+
import numpy.typing as npt
56

67
from .base import BaseSchema, SchemaError
78
from .types import ChunksT, DimsT, ShapeT
@@ -11,13 +12,13 @@ class DTypeSchema(BaseSchema):
1112

1213
_json_schema = {'type': 'string'}
1314

14-
def __init__(self, dtype: np.typing.DTypeLike) -> None:
15+
def __init__(self, dtype: npt.DTypeLike) -> None:
1516
if dtype in [np.floating, np.integer, np.signedinteger, np.unsignedinteger, np.generic]:
1617
self.dtype = dtype
1718
else:
1819
self.dtype = np.dtype(dtype)
1920

20-
def validate(self, dtype: np.typing.DTypeLike) -> None:
21+
def validate(self, dtype: npt.DTypeLike) -> None:
2122
'''Validate dtype
2223
2324
Parameters
@@ -201,3 +202,67 @@ def validate(self, array: Any) -> None:
201202
@property
202203
def json(self) -> str:
203204
return str(self.array_type)
205+
206+
207+
class AttrSchema(BaseSchema):
    '''Schema for a single attribute value.

    Parameters
    ----------
    type : type or tuple of types, optional
        Expected type of the attribute, checked with ``isinstance``.
        ``None`` (the default) disables the type check.
    value : Any, optional
        Expected value of the attribute, compared with ``!=``.
        ``None`` (the default) disables the value check.
    '''

    _json_schema = {'type': 'object'}  # TODO: add type/value here

    def __init__(self, type: Any = None, value: Any = None):
        self.type = type
        self.value = value

    def validate(self, attr: Any):
        '''Validate a single attribute

        Parameters
        ----------
        attr : Any
            Attribute value to check against this schema.

        Raises
        ------
        SchemaError
            If ``attr`` is not an instance of ``self.type`` or is not equal
            to ``self.value`` (each check only applies when the corresponding
            constraint is not ``None``).
        '''

        # The error must actually be raised; merely constructing the exception
        # would let attributes of the wrong type pass silently.
        if self.type is not None and not isinstance(attr, self.type):
            raise SchemaError(f'attrs {attr} is not of type {self.type}')

        if self.value is not None and self.value != attr:
            raise SchemaError(f'attr {attr} != {self.value}')

    @property
    def json(self) -> dict:
        '''Dict with the ``type`` and ``value`` constraints of this schema.

        NOTE: ``type`` is stored as given; when it is a class the resulting
        dict is not serializable by the standard ``json`` module as-is.
        '''
        return {'type': self.type, 'value': self.value}
228+
229+
230+
class AttrsSchema(BaseSchema):
    '''Schema for a mapping of attributes.

    Parameters
    ----------
    attrs : Mapping
        Mapping of attribute name to a schema object exposing
        ``validate(value)`` and ``json`` (e.g. ``AttrSchema``).
    require_all_keys : bool
        If True (default), every key in the schema must be present in the
        validated attrs.
    allow_extra_keys : bool
        If True (default), the validated attrs may contain keys that are not
        part of the schema.
    '''

    # NOTE(review): ``json`` below returns a dict, not a string -- confirm how
    # the base class uses ``_json_schema`` before changing this.
    _json_schema = {'type': 'string'}

    def __init__(
        self, attrs: Mapping, require_all_keys: bool = True, allow_extra_keys: bool = True
    ) -> None:
        self.attrs = attrs
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, attrs: Any) -> None:
        '''Validate attrs

        Parameters
        ----------
        attrs : dict_like
            attrs of the object being validated.

        Raises
        ------
        SchemaError
            If required keys are missing, if extra keys are present while
            ``allow_extra_keys`` is False, or if any present attribute fails
            its own schema.
        '''

        if self.require_all_keys:
            missing_keys = set(self.attrs) - set(attrs)
            if missing_keys:
                raise SchemaError(f'attrs has missing keys: {missing_keys}')

        if not self.allow_extra_keys:
            extra_keys = set(attrs) - set(self.attrs)
            if extra_keys:
                raise SchemaError(f'attrs has extra keys: {extra_keys}')

        for key, attr_schema in self.attrs.items():
            # Missing keys were already rejected above when they are required;
            # when they are optional (require_all_keys=False) skip them rather
            # than failing, so the flag actually has an effect.
            if key in attrs:
                attr_schema.validate(attrs[key])

    @property
    def json(self) -> dict:
        '''Dict mapping each attribute name to its schema's ``json``.'''
        return {k: v.json for k, v in self.attrs.items()}

xarray_schema/dataarray.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Any, Callable, Dict, List, Union
1+
from typing import Any, Callable, Dict, List, Mapping, Union
22

33
import numpy as np
44
import xarray as xr
55

6-
from .base import BaseSchema
6+
from .base import BaseSchema, SchemaError
77
from .components import (
88
ArrayTypeSchema,
9+
AttrsSchema,
910
ChunksSchema,
1011
DimsSchema,
1112
DTypeSchema,
@@ -43,7 +44,9 @@ class DataArraySchema(BaseSchema):
4344
_shape: Union[ShapeSchema, None]
4445
_dims: Union[DimsSchema, None]
4546
_name: Union[NameSchema, None]
47+
_coords: Union[Any, None]
4648
_chunks: Union[ChunksSchema, None]
49+
_attrs: Union[AttrsSchema, None]
4750
_array_type: Union[ArrayTypeSchema, None]
4851

4952
def __init__(
@@ -55,7 +58,7 @@ def __init__(
5558
coords: Dict[str, Any] = None,
5659
chunks: Union[ChunksT, ChunksSchema] = None,
5760
array_type: Any = None,
58-
attrs: Dict[str, Any] = None,
61+
attrs: Mapping[str, Any] = None,
5962
checks: List[Callable] = None,
6063
) -> None:
6164

@@ -135,6 +138,28 @@ def array_type(self, value):
135138
else:
136139
self._array_type = ArrayTypeSchema(value)
137140

141+
@property
142+
def attrs(self) -> AttrsSchema:
143+
return self._attrs
144+
145+
@attrs.setter
146+
def attrs(self, value):
147+
if value is None or isinstance(value, AttrsSchema):
148+
self._attrs = value
149+
else:
150+
self._attrs = AttrsSchema(value)
151+
152+
@property
153+
def coords(self) -> Mapping:
154+
return self._coords
155+
156+
@coords.setter
157+
def coords(self, value):
158+
if value is None or isinstance(value, CoordsSchema):
159+
self._coords = value
160+
else:
161+
self._coords = CoordsSchema(value)
162+
138163
@property
139164
def checks(self) -> List[Callable]:
140165
return self._checks
@@ -180,14 +205,14 @@ def validate(self, da: xr.DataArray) -> None:
180205
if self.shape is not None:
181206
self.shape.validate(da.shape)
182207

183-
if self.coords is not None: # pragma: no cover
184-
raise NotImplementedError('coords schema not implemented yet')
208+
if self.coords is not None:
209+
self.coords.validate(da.coords)
185210

186211
if self.chunks is not None:
187212
self.chunks.validate(da.chunks, da.dims, da.shape)
188213

189-
if self.attrs: # pragma: no cover
190-
raise NotImplementedError('attrs schema not implemented yet')
214+
if self.attrs:
215+
self.attrs.validate(da.attrs)
191216

192217
if self.array_type is not None:
193218
self.array_type.validate(da.data)
@@ -204,3 +229,47 @@ def json(self) -> dict:
204229
except AttributeError:
205230
pass
206231
return obj
232+
233+
234+
class CoordsSchema(BaseSchema):
    '''Schema for a mapping of coordinates.

    Parameters
    ----------
    coords : Mapping[str, Any]
        Mapping of coordinate name to a schema object exposing
        ``validate(coord)`` and ``json``.
    require_all_keys : bool
        If True (default), every key in the schema must be present in the
        validated coords.
    allow_extra_keys : bool
        If True (default), the validated coords may contain keys that are not
        part of the schema.
    '''

    # NOTE(review): ``json`` below returns a dict, not a string -- confirm how
    # the base class uses ``_json_schema`` before changing this.
    _json_schema = {'type': 'string'}

    def __init__(
        self,
        coords: Mapping[str, Any],
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ) -> None:
        self.coords = coords
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, coords: Any) -> None:
        '''Validate coords

        Parameters
        ----------
        coords : dict_like
            coords of the DataArray.

        Raises
        ------
        SchemaError
            If required keys are missing, if extra keys are present while
            ``allow_extra_keys`` is False, or if any present coordinate fails
            its own schema.
        '''

        if self.require_all_keys:
            missing_keys = set(self.coords) - set(coords)
            if missing_keys:
                raise SchemaError(f'coords has missing keys: {missing_keys}')

        if not self.allow_extra_keys:
            extra_keys = set(coords) - set(self.coords)
            if extra_keys:
                raise SchemaError(f'coords has extra keys: {extra_keys}')

        for key, da_schema in self.coords.items():
            # Missing keys were already rejected above when they are required;
            # when they are optional (require_all_keys=False) skip them rather
            # than failing, so the flag actually has an effect.
            if key in coords:
                da_schema.validate(coords[key])

    @property
    def json(self) -> dict:
        '''Dict mapping each coordinate name to its schema's ``json``.'''
        # This class stores its schemas on ``self.coords``; it has no ``attrs``
        # attribute, so reading ``self.attrs`` here would raise AttributeError.
        return {k: v.json for k, v in self.coords.items()}

xarray_schema/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,16 @@ def validate(self, ds: xr.Dataset) -> None:
5959
if self.coords is not None: # pragma: no cover
6060
raise NotImplementedError('coords schema not implemented yet')
6161

62-
if self.attrs: # pragma: no cover
63-
raise NotImplementedError('attrs schema not implemented yet')
62+
if self.attrs:
63+
self.attrs.validate(ds.attrs)
6464

6565
if self.checks:
6666
for check in self.checks:
6767
check(ds)
6868

6969
@property
7070
def json(self):
71-
obj = {'data_vars': {}} # TODO: add when , 'coords': {}, 'attrs': {}}
71+
obj = {'data_vars': {}, 'attrs': self.attrs.json if self.attrs is not None else None}
7272
if self.data_vars:
7373
for key, var in self.data_vars.items():
7474
obj['data_vars'][key] = var.json

0 commit comments

Comments
 (0)