Skip to content

Commit d4eb889

Browse files
committed
adding in bool logic to chunks api
1 parent 87bd449 commit d4eb889

File tree

3 files changed

+59
-28
lines changed

3 files changed

+59
-28
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ select = B,C,E,F,W,T4,B9
88

99
[isort]
1010
known_first_party=xarray_schema
11-
known_third_party=numpy,pkg_resources,pytest,setuptools,xarray
11+
known_third_party=dask,numpy,pkg_resources,pytest,setuptools,xarray
1212
multi_line_output=3
1313
include_trailing_comma=True
1414
force_grid_wrap=0

tests/test_core.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def test_dataarray_validate_chunks():
101101
schema = DataArraySchema(chunks={'x': (2, 1)})
102102
with pytest.raises(SchemaError, match=r'.*(2, 1).*'):
103103
schema.validate(da)
104-
105-
# check that when expected chunk == -1 it fails
104+
105+
# check that when expected chunk == -1 it fails
106106
schema = DataArraySchema(chunks={'x': -1})
107107
with pytest.raises(SchemaError, match=r'.*(4).*'):
108108
schema.validate(da)
@@ -121,6 +121,20 @@ def test_dataarray_validate_chunks():
121121
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': -1})
122122
schema.validate(da)
123123

124+
# test for agnostic chunks
125+
schema = DataArraySchema(chunks=True)
126+
da = xr.DataArray(np.ones(4), dims=['x'])
127+
with pytest.raises(SchemaError, match='Schema expected DataArray to be chunked but it is not'):
128+
schema.validate(da)
129+
130+
# now try passing an irregularly chunked data array
131+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1, 2, 1)})
132+
schema.validate(da)
133+
134+
# test the check for regular chunk sizes
135+
schema = DataArraySchema(chunks={'x': -1})
136+
with pytest.raises(AssertionError, match=r'.*(gracious).*'):
137+
schema.validate(da)
124138

125139

126140
def test_dataset_empty_constructor():
@@ -146,8 +160,12 @@ def test_dataset_example():
146160
)
147161
ds_schema.validate(ds)
148162

163+
149164
def test_validate():
150-
schema = DataArraySchema()
151-
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1,2,1)})
152-
with pytest.raises(AssertionError, match=r'.*(gracious).*'):
153-
schema.validate(da)
165+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1, 2, 1)})
166+
schema = DataArraySchema(chunks=False)
167+
# check that da is unchunked
168+
with pytest.raises(SchemaError, match='Schema expected unchunked DataArray but it is chunked!'):
169+
schema.validate(da)
170+
da = xr.DataArray(np.ones(4), dims=['x'])
171+
schema.validate(da)

xarray_schema/core.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import xarray as xr
55
from dask.array.core import _check_regular_chunks
6+
67
# TODOs:
78
# - api grouping, should the constructors look similar to the DataArray/Dataset constructors
89

@@ -24,6 +25,9 @@ class DataArraySchema:
2425
Shape of the DataArray. `None` may be used as a wildcard value. By default None
2526
dims : Tuple[Union[Hashable, None]], optional
2627
Dimensions of the DataArray. `None` may be used as a wildcard value. By default None
28+
chunks : Union[bool, Dict[Hashable, Union[int, None]]], optional
29+
If bool, specifies whether DataArray is chunked or not, agnostic to chunk sizes.
30+
If dict, includes the expected chunks for the DataArray, by default None
2731
name : str, optional
2832
Name of the DataArray, by default None
2933
array_type : Any, optional
@@ -37,7 +41,7 @@ def __init__(
3741
shape: Tuple[Union[int, None]] = None,
3842
dims: Tuple[Union[Hashable, None]] = None,
3943
coords: Dict[Hashable, Any] = None,
40-
chunks: Dict[Hashable, Union[int, None]] = None,
44+
chunks: Union[bool, Dict[Hashable, Union[int, None]]] = None,
4145
name: str = None,
4246
array_type: Any = None,
4347
attrs: Dict[Hashable, Any] = None,
@@ -74,9 +78,10 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
7478
'''
7579
if not isinstance(da, xr.DataArray):
7680
raise ValueError('Input must be a xarray.DataArray')
77-
78-
if da.chunks:
79-
assert _check_regular_chunks(da.chunks), 'Good gracious no! Chunks are not regular!'
81+
82+
if self.chunks is not None:
83+
if self.chunks is False and da.chunks:
84+
raise SchemaError('Schema expected unchunked DataArray but it is chunked!')
8085

8186
if self.dtype is not None and not np.issubdtype(da.dtype, self.dtype):
8287
raise SchemaError(f'dtype {da.dtype} != {self.dtype}')
@@ -108,22 +113,30 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
108113
raise NotImplementedError('coords schema not implemented yet')
109114

110115
if self.chunks:
111-
dim_chunks = dict(zip(da.dims, da.chunks))
112-
dim_sizes = dict(zip(da.dims, da.shape))
113-
for key, ec in self.chunks.items():
114-
if isinstance(ec, int):
115-
# handles case of expected chunksize is shorthand of -1 which translates to the full length of dimension
116-
if ec < 0:
117-
ec = dim_sizes[key]
118-
# grab the first entry in da's tuple of chunks to be representative (as it should be assuming they're regular)
119-
ac = dim_chunks[key][0]
120-
if ac != ec:
121-
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
122-
123-
else: # assumes ec is an iterable
124-
ac = dim_chunks[key]
125-
if tuple(ac) != tuple(ec):
126-
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
116+
if self.chunks is True:
117+
if not da.chunks:
118+
raise SchemaError('Schema expected DataArray to be chunked but it is not')
119+
120+
else:
121+
assert type(self.chunks) == dict, 'Must pass chunks information as dictionary'
122+
dim_chunks = dict(zip(da.dims, da.chunks))
123+
dim_sizes = dict(zip(da.dims, da.shape))
124+
# check whether chunk sizes are regular because we assume the first chunk to be representative below
125+
assert _check_regular_chunks(da.chunks), 'Good gracious no! Chunks are not regular!'
126+
for key, ec in self.chunks.items():
127+
if isinstance(ec, int):
128+
# handles case of expected chunksize is shorthand of -1 which translates to the full length of dimension
129+
if ec < 0:
130+
ec = dim_sizes[key]
131+
# grab the first entry in da's tuple of chunks to be representative (since we've checked above that they're regular)
132+
ac = dim_chunks[key][0]
133+
if ac != ec:
134+
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
135+
136+
else: # assumes ec is an iterable
137+
ac = dim_chunks[key]
138+
if tuple(ac) != tuple(ec):
139+
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
127140

128141
if self.attrs:
129142
raise NotImplementedError('attrs schema not implemented yet')
@@ -134,7 +147,7 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
134147
if self.checks:
135148
for check in self.checks:
136149
da = check(da)
137-
150+
138151
return da
139152

140153

0 commit comments

Comments
 (0)