Skip to content

Commit 6758211

Browse files
committed
merging main
2 parents a57a3ad + 2689bc4 commit 6758211

File tree

4 files changed

+79
-12
lines changed

4 files changed

+79
-12
lines changed

.github/workflows/main.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
with:
3030
python-version: ${{ matrix.python-version }}
3131
architecture: x64
32-
- uses: actions/cache@v2.1.6
32+
- uses: actions/cache@v2.1.7
3333
with:
3434
path: ~/.cache/pip
3535
key: ${{ runner.os }}-pip-${{ hashFiles('**/dev-requirements.txt') }}
@@ -62,7 +62,7 @@ jobs:
6262
with:
6363
python-version: ${{ matrix.python-version }}
6464
architecture: x64
65-
- uses: actions/cache@v2.1.6
65+
- uses: actions/cache@v2.1.7
6666
with:
6767
path: ~/.cache/pip
6868
key: ${{ runner.os }}-pip-${{ hashFiles('**/dev-requirements.txt') }}

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ select = B,C,E,F,W,T4,B9
88

99
[isort]
1010
known_first_party=xarray_schema
11-
known_third_party=numpy,pkg_resources,pytest,setuptools,xarray
11+
known_third_party=dask,numpy,pkg_resources,pytest,setuptools,xarray
1212
multi_line_output=3
1313
include_trailing_comma=True
1414
force_grid_wrap=0

tests/test_core.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,40 @@ def test_dataarray_validate_chunks():
102102
with pytest.raises(SchemaError, match=r'.*(2, 1).*'):
103103
schema.validate(da)
104104

105+
# check that when expected chunk == -1 it fails
106+
schema = DataArraySchema(chunks={'x': -1})
107+
with pytest.raises(SchemaError, match=r'.*(4).*'):
108+
schema.validate(da)
109+
110+
# check that when chunking schema is -1 it also works
111+
# both when chunking is specified as -1 and as 4
112+
schema = DataArraySchema(chunks={'x': 4})
113+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': -1})
114+
schema.validate(da)
115+
116+
schema = DataArraySchema(chunks={'x': -1})
117+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': 4})
118+
schema.validate(da)
119+
120+
schema = DataArraySchema(chunks={'x': -1})
121+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': -1})
122+
schema.validate(da)
123+
124+
# test for agnostic chunks
125+
schema = DataArraySchema(chunks=True)
126+
da = xr.DataArray(np.ones(4), dims=['x'])
127+
with pytest.raises(SchemaError, match='Schema expected DataArray to be chunked but it is not'):
128+
schema.validate(da)
129+
130+
# now try passing an irregularly chunked data array
131+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1, 2, 1)})
132+
schema.validate(da)
133+
134+
# test the check for regular chunk sizes
135+
schema = DataArraySchema(chunks={'x': -1})
136+
with pytest.raises(AssertionError, match=r'.*(gracious).*'):
137+
schema.validate(da)
138+
105139

106140
def test_dataset_empty_constructor():
107141
ds_schema = DatasetSchema()
@@ -125,3 +159,13 @@ def test_dataset_example():
125159
}
126160
)
127161
ds_schema.validate(ds)
162+
163+
164+
def test_validate():
165+
da = xr.DataArray(np.ones(4), dims=['x']).chunk({'x': (1, 2, 1)})
166+
schema = DataArraySchema(chunks=False)
167+
# check that da is unchunked
168+
with pytest.raises(SchemaError, match='Schema expected unchunked DataArray but it is chunked!'):
169+
schema.validate(da)
170+
da = xr.DataArray(np.ones(4), dims=['x'])
171+
schema.validate(da)

xarray_schema/core.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import xarray as xr
5+
from dask.array.core import _check_regular_chunks
56

67
# TODOs:
78
# - api grouping, should the constructors look similar to the DataArray/Dataset constructors
@@ -24,6 +25,9 @@ class DataArraySchema:
2425
Shape of the DataArray. `None` may be used as a wildcard value. By default None
2526
dims : Tuple[Union[Hashable, None]], optional
2627
Dimensions of the DataArray. `None` may be used as a wildcard value. By default None
28+
chunks : Union[bool, Dict[Hashable, Union[int, None]]], optional
29+
If bool, specifies whether DataArray is chunked or not, agnostic to chunk sizes.
30+
If dict, includes the expected chunks for the DataArray, by default None
2731
name : str, optional
2832
Name of the DataArray, by default None
2933
array_type : Any, optional
@@ -37,7 +41,7 @@ def __init__(
3741
shape: Tuple[Union[int, None]] = None,
3842
dims: Tuple[Union[Hashable, None]] = None,
3943
coords: Dict[Hashable, Any] = None,
40-
chunks: Dict[Hashable, Union[int, None]] = None,
44+
chunks: Union[bool, Dict[Hashable, Union[int, None]]] = None,
4145
name: str = None,
4246
array_type: Any = None,
4347
attrs: Dict[Hashable, Any] = None,
@@ -72,6 +76,12 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
7276
------
7377
SchemaError
7478
'''
79+
if not isinstance(da, xr.DataArray):
80+
raise ValueError('Input must be a xarray.DataArray')
81+
82+
if self.chunks is not None:
83+
if self.chunks is False and da.chunks:
84+
raise SchemaError('Schema expected unchunked DataArray but it is chunked!')
7585

7686
if self.dtype is not None and not np.issubdtype(da.dtype, self.dtype):
7787
raise SchemaError(f'dtype {da.dtype} != {self.dtype}')
@@ -103,17 +113,30 @@ def validate(self, da: xr.DataArray) -> xr.DataArray:
103113
raise NotImplementedError('coords schema not implemented yet')
104114

105115
if self.chunks:
106-
dim_chunks = dict(zip(da.dims, da.chunks))
107-
for key, ec in self.chunks.items():
108-
if isinstance(ec, int):
109-
for ac in dim_chunks[key][:-1]:
116+
if self.chunks is True:
117+
if not da.chunks:
118+
raise SchemaError('Schema expected DataArray to be chunked but it is not')
119+
120+
else:
121+
assert type(self.chunks) == dict, 'Must pass chunks information as dictionary'
122+
dim_chunks = dict(zip(da.dims, da.chunks))
123+
dim_sizes = dict(zip(da.dims, da.shape))
124+
# check whether chunk sizes are regular because we assume the first chunk to be representative below
125+
assert _check_regular_chunks(da.chunks), 'Good gracious no! Chunks are not regular!'
126+
for key, ec in self.chunks.items():
127+
if isinstance(ec, int):
128+
# handles case of expected chunksize is shorthand of -1 which translates to the full length of dimension
129+
if ec < 0:
130+
ec = dim_sizes[key]
131+
# grab the first entry in da's tuple of chunks to be representative (since we've checked above that they're regular)
132+
ac = dim_chunks[key][0]
110133
if ac != ec:
111134
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
112135

113-
else: # assumes ec is an iterable
114-
ac = dim_chunks[key]
115-
if tuple(ac) != tuple(ec):
116-
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
136+
else: # assumes ec is an iterable
137+
ac = dim_chunks[key]
138+
if tuple(ac) != tuple(ec):
139+
raise SchemaError(f'{key} chunks did not match: {ac} != {ec}')
117140

118141
if self.attrs:
119142
raise NotImplementedError('attrs schema not implemented yet')

0 commit comments

Comments
 (0)