Skip to content

Commit 8f7dd7e

Browse files
committed
Adding tests for chunks validation
1 parent 10416d4 commit 8f7dd7e

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

tests/test_core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,26 @@ def test_dataarray_validate_chunks():
101101
schema = DataArraySchema(chunks={'x': (2, 1)})
102102
with pytest.raises(SchemaError, match=r'.*(2, 1).*'):
103103
schema.validate(da)
104+
105+
# check that a schema chunk size of -1 (shorthand for the full dimension length, 4) raises SchemaError for this array
106+
schema = DataArraySchema(chunks={'x': -1})
107+
with pytest.raises(SchemaError, match=r'.*(4).*'):
108+
schema.validate(da)
109+
110+
# check that validation passes when -1 and the full dimension length (4) are used interchangeably,
111+
# whether -1 appears in the schema, in the array's chunking, or in both
112+
schema = DataArraySchema(chunks={'x': 4})
113+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': -1})
114+
schema.validate(da)
115+
116+
schema = DataArraySchema(chunks={'x': -1})
117+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': 4})
118+
schema.validate(da)
119+
120+
schema = DataArraySchema(chunks={'x': -1})
121+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': -1})
122+
schema.validate(da)
123+
104124

105125

106126
def test_dataset_empty_constructor():
@@ -125,3 +145,9 @@ def test_dataset_example():
125145
}
126146
)
127147
ds_schema.validate(ds)
148+
149+
def test_validate():
    """Irregularly chunked arrays are rejected by ``DataArraySchema.validate``.

    A length-4 array is split along ``x`` into chunks of sizes (1, 2, 1) and
    validated against a schema constructed with no arguments. The call is
    expected to raise an ``AssertionError`` whose message contains "gracious".
    """
    irregular = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1, 2, 1)})
    validator = DataArraySchema()
    with pytest.raises(AssertionError, match=r'.*(gracious).*'):
        validator.validate(irregular)

xarray_schema/core.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import xarray as xr
5-
5+
from dask.array.core import _check_regular_chunks
66
# TODOs:
77
# - api grouping, should the constructors look similar to the DataArray/Dataset constructors
88

@@ -72,6 +72,8 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
7272
------
7373
SchemaError
7474
'''
75+
if da.chunks:
76+
assert _check_regular_chunks(da.chunks), 'Good gracious no! Chunks are not regular!'
7577
if not isinstance(da, xr.DataArray):
7678
raise ValueError('Input must be a xarray.DataArray')
7779

@@ -106,11 +108,12 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
106108

107109
if self.chunks:
108110
dim_chunks = dict(zip(da.dims, da.chunks))
111+
dim_sizes = dict(zip(da.dims, da.shape))
109112
for key, ec in self.chunks.items():
110113
if isinstance(ec, int):
111114
# handles the case where the expected chunk size is the shorthand -1, which translates to the full length of the dimension
112-
if ec==-1:
113-
ec = len(da[key])
115+
if ec == -1:
116+
ec = dim_sizes[key]
114117
# use the first entry of da's chunk tuple for this dimension as representative (valid as long as the chunks are regular)
115118
ac = dim_chunks[key][0]
116119
if ac != ec:

0 commit comments

Comments
 (0)