Skip to content

Commit 70ee97c

Browse files
Oriana ChegwiddenOriana Chegwidden
authored andcommitted
Merge remote-tracking branch 'origin/main' into chunks
2 parents bcdb7a9 + 0a60625 commit 70ee97c

File tree

3 files changed

+39
-10
lines changed

3 files changed

+39
-10
lines changed

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
dask
12
pytest
23
pytest-cov
34
-r requirements.txt

tests/test_core.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,28 @@ def test_dataarray_validate_array_type():
8181
schema.validate(da)
8282

8383

84+
def test_dataarray_validate_chunks():
85+
pytest.importorskip('dask')
86+
87+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': 2})
88+
schema = DataArraySchema(chunks={'x': 2})
89+
schema.validate(da)
90+
91+
schema = DataArraySchema(chunks={'x': (2, 2)})
92+
schema.validate(da)
93+
94+
schema = DataArraySchema(chunks={'x': [2, 2]})
95+
schema.validate(da)
96+
97+
schema = DataArraySchema(chunks={'x': 3})
98+
with pytest.raises(SchemaError, match=r'.*(3).*'):
99+
schema.validate(da)
100+
101+
schema = DataArraySchema(chunks={'x': (2, 1)})
102+
with pytest.raises(SchemaError, match=r'.*(2, 1).*'):
103+
schema.validate(da)
104+
105+
84106
def test_dataset_empty_constructor():
85107
ds_schema = DatasetSchema()
86108
assert hasattr(ds_schema, 'validate')

xarray_schema/core.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
7272
------
7373
SchemaError
7474
'''
75+
assert isinstance(da, xr.core.dataarray.DataArray),'Input is not an xarray DataArray and schema for chunks are not yet implemented'
7576

7677
if self.dtype is not None and not np.issubdtype(da.dtype, self.dtype):
7778
raise SchemaError(f'dtype {da.dtype} != {self.dtype}')
@@ -103,14 +104,21 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
103104
raise NotImplementedError('coords schema not implemented yet')
104105

105106
if self.chunks:
106-
# ensure that the chunks are what you want them to be
107-
for dim, expected in self.chunks.items():
108-
# for special case of chunksize=-1, make the expected equal to the full length of that dimension
109-
if expected == -1:
110-
expected = len(da[dim])
111-
actual = da.chunks[dim][0]
112-
if actual != expected:
113-
raise SchemaError(f'chunk mismatch for dimension {dim}: {actual} != {expected}')
107+
dim_chunks = dict(zip(da.dims, da.chunks))
108+
for key, ec in self.chunks.items():
109+
if isinstance(ec, int):
110+
# handles case of expected chunksize is shorthand of -1 which translates to the full length of dimension
111+
if ec==-1:
112+
ec = len(da[key])
113+
# grab the first entry in da's tuple of chunks to be representative (as it should be assuming they're regular)
114+
ac = dim_chunks[key][0]
115+
if ac != ec:
116+
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
117+
118+
else: # assumes ec is an iterable
119+
ac = dim_chunks[key]
120+
if tuple(ac) != tuple(ec):
121+
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
114122

115123
if self.attrs:
116124
raise NotImplementedError('attrs schema not implemented yet')
@@ -122,8 +130,6 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
122130
for check in self.checks:
123131
da = check(da)
124132

125-
return da
126-
127133

128134
class DatasetSchema:
129135
'''A light-weight xarray.Dataset validator

0 commit comments

Comments
 (0)