Skip to content

Commit 165e81c

Browse files
author
Joe Hamman
authored
Merge pull request #7 from carbonplan/feature/component-schemas
Refactor / add component schemas
2 parents 2b4b4de + 91b2967 commit 165e81c

File tree

11 files changed

+703
-357
lines changed

11 files changed

+703
-357
lines changed

.github/workflows/main.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@ on:
99
- cron: "0 0 * * *"
1010

1111
jobs:
12-
lint:
13-
runs-on: ubuntu-latest
14-
steps:
15-
- uses: actions/[email protected]
16-
- uses: actions/[email protected]
17-
- uses: pre-commit/[email protected]
18-
1912
test:
2013
name: ${{ matrix.python-version }}-build
2114
runs-on: ubuntu-latest

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Schema validation for Xarray
99

1010
[![CI](https://github.com/carbonplan/xarray-schema/actions/workflows/main.yaml/badge.svg)](https://github.com/carbonplan/xarray-schema/actions/workflows/main.yaml)
11+
[![codecov](https://codecov.io/gh/carbonplan/xarray-schema/branch/main/graph/badge.svg?token=EI729ZRFK0)](https://codecov.io/gh/carbonplan/xarray-schema)
1112
![MIT License](https://badgen.net/badge/license/MIT/blue)
1213

1314
# installation
@@ -48,8 +49,7 @@ schema_ds.validate(da.to_dataset())
4849

4950
This is a very early prototype of a library. Some key things are missing:
5051

51-
1. Validation of `coords`, `chunks`, and `attrs`. None of these are implemented yet.
52-
1. Class-based schemas for parts of the Xarray data model. Most validations are currently made as direct comparisons (`da.name == self.name`) but a more robust approach is possible that leverages classes for each component of the data model. We're already handling some special cases using `None` as a sentinel value to allow for wildcard-like behavior in places (i.e. `dims` and `shape`).
52+
1. Validation of `coords` and `attrs`. These are not implemented yet.
5353
1. Exceptions: Pandera accumulates schema exceptions and reports them all at once. Currently, we are eagerly raising `SchemaErrors` when they are found.
5454

5555
## license

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ select = B,C,E,F,W,T4,B9
88

99
[isort]
1010
known_first_party=xarray_schema
11-
known_third_party=dask,numpy,pkg_resources,pytest,setuptools,xarray
11+
known_third_party=numpy,pkg_resources,pytest,setuptools,xarray
1212
multi_line_output=3
1313
include_trailing_comma=True
1414
force_grid_wrap=0

tests/test_core.py

Lines changed: 172 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -3,169 +3,209 @@
33
import xarray as xr
44

55
from xarray_schema import DataArraySchema, DatasetSchema
6-
from xarray_schema.core import SchemaError
6+
from xarray_schema.base import SchemaError
7+
from xarray_schema.components import (
8+
ArrayTypeSchema,
9+
ChunksSchema,
10+
DimsSchema,
11+
DTypeSchema,
12+
NameSchema,
13+
ShapeSchema,
14+
)
15+
16+
17+
@pytest.fixture
18+
def ds():
19+
ds = xr.Dataset(
20+
{
21+
'x': xr.DataArray(np.arange(4) - 2, dims='x'),
22+
'foo': xr.DataArray(np.ones(4, dtype='i4'), dims='x'),
23+
'bar': xr.DataArray(np.arange(8, dtype=np.float32).reshape(4, 2), dims=('x', 'y')),
24+
}
25+
)
26+
return ds
27+
28+
29+
@pytest.mark.parametrize(
30+
'component, schema_args, validate, json',
31+
[
32+
(DTypeSchema, np.integer, ['i4', 'int', np.int32], 'integer'),
33+
(DTypeSchema, np.int64, ['i8', np.int64], '<i8'),
34+
(DTypeSchema, '<i8', ['i8', np.int64], '<i8'),
35+
(DimsSchema, ('foo', None), [('foo', 'bar'), ('foo', 'baz')], ['foo', None]),
36+
(DimsSchema, ('foo', 'bar'), [('foo', 'bar')], ['foo', 'bar']),
37+
(ShapeSchema, (1, 2, None), [(1, 2, 3), (1, 2, 5)], [1, 2, None]),
38+
(ShapeSchema, (1, 2, 3), [(1, 2, 3)], [1, 2, 3]),
39+
(NameSchema, 'foo', ['foo'], 'foo'),
40+
(ArrayTypeSchema, np.ndarray, [np.array([1, 2, 3])], "<class 'numpy.ndarray'>"),
41+
# for ChunksSchema, schema_args is the chunks spec and each validate entry is (chunks, dims, shape)
42+
(ChunksSchema, True, [(((1, 1),), ('x',), (2,))], True),
43+
(ChunksSchema, {'x': 2}, [(((2, 2),), ('x',), (4,))], {'x': 2}),
44+
(ChunksSchema, {'x': (2, 2)}, [(((2, 2),), ('x',), (4,))], {'x': [2, 2]}),
45+
(ChunksSchema, {'x': [2, 2]}, [(((2, 2),), ('x',), (4,))], {'x': [2, 2]}),
46+
(ChunksSchema, {'x': 4}, [(((4,),), ('x',), (4,))], {'x': 4}),
47+
(ChunksSchema, {'x': -1}, [(((4,),), ('x',), (4,))], {'x': -1}),
48+
(ChunksSchema, {'x': (1, 2, 1)}, [(((1, 2, 1),), ('x',), (4,))], {'x': [1, 2, 1]}),
49+
(
50+
ChunksSchema,
51+
{'x': 2, 'y': -1},
52+
[(((2, 2), (10,)), ('x', 'y'), (4, 10))],
53+
{'x': 2, 'y': -1},
54+
),
55+
],
56+
)
57+
def test_component_schema(component, schema_args, validate, json):
58+
schema = component(schema_args)
59+
for v in validate:
60+
if component in [ChunksSchema]: # special case construction
61+
schema.validate(*v)
62+
else:
63+
schema.validate(v)
64+
assert schema.json == json
65+
assert isinstance(schema.to_json(), str)
66+
67+
68+
@pytest.mark.parametrize(
69+
'component, schema_args, validate, match',
70+
[
71+
(DTypeSchema, np.integer, np.float32, r'.*float.*'),
72+
(DimsSchema, ('foo', 'bar'), ('foo',), r'.*length.*'),
73+
(DimsSchema, ('foo', 'bar'), ('foo', 'baz'), r'.*mismatch.*'),
74+
(ShapeSchema, (1, 2, None), (1, 2), r'.*number of dimensions.*'),
75+
(ShapeSchema, (1, 4, 4), (1, 3, 4), r'.*mismatch.*'),
76+
(NameSchema, 'foo', 'bar', r'.*name bar != foo.*'),
77+
(ArrayTypeSchema, np.ndarray, 'bar', r'.*array_type.*'),
78+
# for ChunksSchema, schema_args is the chunks spec and validate is (chunks, dims, shape)
79+
(ChunksSchema, {'x': 3}, (((2, 2),), ('x',), (4,)), r'.*(3).*'),
80+
(ChunksSchema, {'x': (2, 1)}, (((2, 2),), ('x',), (4,)), r'.*(2, 1).*'),
81+
(ChunksSchema, True, (None, ('x',), (4,)), r'.*expected array to be chunked.*'),
82+
(
83+
ChunksSchema,
84+
False,
85+
(((2, 2),), ('x',), (4,)),
86+
r'.*expected unchunked array but it is chunked*',
87+
),
88+
(ChunksSchema, {'x': -1}, (((1, 2, 1),), ('x',), (4,)), r'.*did not match.*'),
89+
(ChunksSchema, {'x': 2}, (((2, 3, 2),), ('x',), (7,)), r'.*did not match.*'),
90+
(ChunksSchema, {'x': 2}, (((2, 2, 3),), ('x',), (7,)), r'.*did not match.*'),
91+
(ChunksSchema, {'x': 2, 'y': -1}, (((2, 2), (5, 5)), ('x', 'y'), (4, 10)), r'.*(5).*'),
92+
],
93+
)
94+
def test_component_raises_schema_error(component, schema_args, validate, match):
95+
schema = component(schema_args)
96+
with pytest.raises(SchemaError, match=match):
97+
if component in [ChunksSchema]: # special case construction
98+
schema.validate(*validate)
99+
else:
100+
schema.validate(validate)
101+
102+
103+
def test_chunks_schema_raises_for_invalid_chunks():
104+
with pytest.raises(ValueError, match=r'.*int.*'):
105+
schema = ChunksSchema(chunks=2)
106+
schema.validate(((2, 2),), ('x',), (4,))
7107

8108

9109
def test_dataarray_empty_constructor():
10110

111+
da = xr.DataArray(np.ones(4, dtype='i4'))
11112
da_schema = DataArraySchema()
12113
assert hasattr(da_schema, 'validate')
114+
assert da_schema.json == {}
115+
da_schema.validate(da)
116+
117+
118+
@pytest.mark.parametrize(
119+
'kind, component, schema_args',
120+
[
121+
('dtype', DTypeSchema, 'i4'),
122+
('dims', DimsSchema, ('x', None)),
123+
('shape', ShapeSchema, (2, None)),
124+
('name', NameSchema, 'foo'),
125+
('array_type', ArrayTypeSchema, np.ndarray),
126+
('chunks', ChunksSchema, False),
127+
],
128+
)
129+
def test_dataarray_component_constructors(kind, component, schema_args):
130+
da = xr.DataArray(np.zeros((2, 4), dtype='i4'), dims=('x', 'y'), name='foo')
131+
comp_schema = component(schema_args)
132+
schema = DataArraySchema(**{kind: schema_args})
133+
assert comp_schema.json == getattr(schema, kind).json
134+
assert isinstance(getattr(schema, kind), component)
13135

14-
15-
def test_dataarray_validate_dtype():
16-
17-
da = xr.DataArray(np.ones(4, dtype='i4'))
18-
schema = DataArraySchema(dtype='i4')
19-
schema.validate(da)
20-
21-
schema = DataArraySchema(dtype=np.int32)
22-
schema.validate(da)
23-
24-
schema = DataArraySchema(dtype=np.integer)
25-
schema.validate(da)
26-
27-
schema = DataArraySchema(dtype=np.floating)
28-
with pytest.raises(SchemaError, match=r'.*floating.*'):
29-
schema.validate(da)
30-
31-
32-
def test_dataarray_validate_name():
33-
34-
da = xr.DataArray(np.ones(4), name='foo')
35-
schema = DataArraySchema(name='foo')
36136
schema.validate(da)
37137

38-
schema = DataArraySchema(name='bar')
39-
with pytest.raises(SchemaError, match=r'.*foo.*'):
40-
schema.validate(da)
41-
42-
43-
def test_dataarray_validate_shape():
44138

45-
da = xr.DataArray(np.ones(4))
46-
schema = DataArraySchema(shape=(4,))
47-
schema.validate(da)
48-
49-
schema = DataArraySchema(shape=(4, 2))
50-
with pytest.raises(SchemaError, match=r'.*ndim.*'):
51-
schema.validate(da)
139+
def test_dataarray_schema_validate_raises_for_invalid_input_type():
140+
ds = xr.Dataset()
141+
schema = DataArraySchema()
142+
with pytest.raises(ValueError, match='Input must be a xarray.DataArray'):
143+
schema.validate(ds)
52144

53-
schema = DataArraySchema(shape=(3,))
54-
with pytest.raises(SchemaError, match=r'.*(4).*'):
55-
schema.validate(da)
56-
57-
58-
def test_dataarray_validate_dims():
59-
60-
da = xr.DataArray(np.ones(4), dims=['x'])
61-
schema = DataArraySchema(dims=['x'])
62-
schema.validate(da)
63145

64-
schema = DataArraySchema(dims=(['x', 'y']))
65-
with pytest.raises(SchemaError, match=r'.*length of dims.*'):
66-
schema.validate(da)
146+
def test_dataset_empty_constructor():
147+
ds_schema = DatasetSchema()
148+
assert hasattr(ds_schema, 'validate')
149+
ds_schema.json == {}
67150

68-
schema = DataArraySchema(dims=['y'])
69-
with pytest.raises(SchemaError, match=r'.*(y).*'):
70-
schema.validate(da)
71151

152+
def test_dataset_example(ds):
72153

73-
def test_dataarray_validate_array_type():
154+
ds_schema = DatasetSchema(
155+
{
156+
'foo': DataArraySchema(name='foo', dtype=np.int32, dims=['x']),
157+
'bar': DataArraySchema(name='bar', dtype=np.floating, dims=['x', 'y']),
158+
}
159+
)
160+
assert list(ds_schema.json['data_vars'].keys()) == ['foo', 'bar']
161+
ds_schema.validate(ds)
74162

75-
da = xr.DataArray(np.ones(4), dims=['x'])
76-
schema = DataArraySchema(array_type=np.ndarray)
77-
schema.validate(da)
163+
ds['foo'] = ds.foo.astype('float32')
164+
with pytest.raises(SchemaError, match='dtype'):
165+
ds_schema.validate(ds)
78166

79-
schema = DataArraySchema(array_type=float)
80-
with pytest.raises(SchemaError, match=r'.*(float).*'):
81-
schema.validate(da)
167+
ds = ds.drop_vars('foo')
168+
with pytest.raises(SchemaError, match='variable foo'):
169+
ds_schema.validate(ds)
82170

83171

84-
def test_dataarray_validate_chunks():
85-
pytest.importorskip('dask')
172+
def test_checks_ds(ds):
173+
def check_foo(ds):
174+
assert 'foo' in ds
86175

87-
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': 2})
88-
schema = DataArraySchema(chunks={'x': 2})
89-
schema.validate(da)
176+
ds_schema = DatasetSchema(checks=[check_foo])
177+
ds_schema.validate(ds)
90178

91-
schema = DataArraySchema(chunks={'x': (2, 2)})
92-
schema.validate(da)
179+
ds = ds.drop_vars('foo')
180+
with pytest.raises(AssertionError):
181+
ds_schema.validate(ds)
93182

94-
schema = DataArraySchema(chunks={'x': [2, 2]})
95-
schema.validate(da)
183+
ds_schema = DatasetSchema(checks=[])
184+
ds_schema.validate(ds)
96185

97-
schema = DataArraySchema(chunks={'x': 3})
98-
with pytest.raises(SchemaError, match=r'.*(3).*'):
99-
schema.validate(da)
186+
# TODO
187+
# with pytest.raises(ValueError):
188+
# DatasetSchema(checks=[2])
100189

101-
schema = DataArraySchema(chunks={'x': (2, 1)})
102-
with pytest.raises(SchemaError, match=r'.*(2, 1).*'):
103-
schema.validate(da)
104190

105-
# check that when expected chunk == -1 it fails
106-
schema = DataArraySchema(chunks={'x': -1})
107-
with pytest.raises(SchemaError, match=r'.*(4).*'):
108-
schema.validate(da)
191+
def test_checks_da(ds):
192+
da = ds['foo']
109193

110-
# check that when chunking schema is -1 it also works
111-
# both when chunking is specified as -1 and as 4
112-
schema = DataArraySchema(chunks={'x': 4})
113-
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': -1})
114-
schema.validate(da)
194+
def check_foo(da):
195+
assert da.name == 'foo'
115196

116-
schema = DataArraySchema(chunks={'x': -1})
117-
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': 4})
118-
schema.validate(da)
197+
def check_bar(da):
198+
assert da.name == 'bar'
119199

120-
schema = DataArraySchema(chunks={'x': -1})
121-
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': -1})
200+
schema = DataArraySchema(checks=[check_foo])
122201
schema.validate(da)
123202

124-
# test for agnostic chunks
125-
schema = DataArraySchema(chunks=True)
126-
da = xr.DataArray(np.ones(4), dims=['x'])
127-
with pytest.raises(SchemaError, match='Schema expected DataArray to be chunked but it is not'):
203+
schema = DataArraySchema(checks=[check_bar])
204+
with pytest.raises(AssertionError):
128205
schema.validate(da)
129206

130-
# now try passing an irregularly chunked data array
131-
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1, 2, 1)})
207+
schema = DataArraySchema(checks=[])
132208
schema.validate(da)
133209

134-
# test the check for regular chunk sizes
135-
schema = DataArraySchema(chunks={'x': -1})
136-
with pytest.raises(AssertionError, match=r'.*(gracious).*'):
137-
schema.validate(da)
138-
139-
140-
def test_dataset_empty_constructor():
141-
ds_schema = DatasetSchema()
142-
assert hasattr(ds_schema, 'validate')
143-
144-
145-
def test_dataset_example():
146-
147-
ds = xr.Dataset(
148-
{
149-
'x': xr.DataArray(np.arange(4) - 2, dims='x'),
150-
'foo': xr.DataArray(np.ones(4, dtype='i4'), dims='x'),
151-
'bar': xr.DataArray(np.arange(8, dtype=np.float32).reshape(4, 2), dims=('x', 'y')),
152-
}
153-
)
154-
155-
ds_schema = DatasetSchema(
156-
{
157-
'foo': DataArraySchema(name='foo', dtype=np.int32, dims=['x']),
158-
'bar': DataArraySchema(name='bar', dtype=np.floating, dims=['x', 'y']),
159-
}
160-
)
161-
ds_schema.validate(ds)
162-
163-
164-
def test_validate():
165-
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1, 2, 1)})
166-
schema = DataArraySchema(chunks=False)
167-
# check that da is unchunked
168-
with pytest.raises(SchemaError, match='Schema expected unchunked DataArray but it is chunked!'):
169-
schema.validate(da)
170-
da = xr.DataArray(np.ones(4), dims=['x'])
171-
schema.validate(da)
210+
with pytest.raises(ValueError):
211+
DataArraySchema(checks=[2])

xarray_schema/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from pkg_resources import DistributionNotFound, get_distribution
22

3-
from .core import DataArraySchema, DatasetSchema # noqa: F401
3+
from .components import ( # noqa: F401
4+
ArrayTypeSchema,
5+
ChunksSchema,
6+
DimsSchema,
7+
DTypeSchema,
8+
NameSchema,
9+
ShapeSchema,
10+
)
11+
from .dataarray import DataArraySchema # noqa: F401
12+
from .dataset import DatasetSchema # noqa: F401
413

514
try:
615
__version__ = get_distribution(__name__).version

0 commit comments

Comments
 (0)