Skip to content

Commit 05b978c

Browse files
author
Joe Hamman
authored
Merge pull request #63 from meghanrjones/slice-exceptions
Use exceptions rather than assert statements for generator
2 parents f8c45db + 34c4c1a commit 05b978c

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

xbatcher/generators.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@ def _as_xarray_dataset(ds):
1717

1818
def _slices(dimsize, size, overlap=0):
1919
# return a list of slices to chop up a single dimension
20+
if overlap >= size:
21+
raise ValueError(
22+
'input overlap must be less than the input sample length, but '
23+
f'the input sample length is {size} and the overlap is {overlap}'
24+
)
2025
slices = []
2126
stride = size - overlap
22-
assert stride > 0
23-
assert stride <= dimsize
2427
for start in range(0, dimsize, stride):
2528
end = start + size
2629
if end <= dimsize:
@@ -34,6 +37,13 @@ def _iterate_through_dataset(ds, dims, overlap={}):
3437
dimsize = ds.dims[dim]
3538
size = dims[dim]
3639
olap = overlap.get(dim, 0)
40+
if size > dimsize:
41+
raise ValueError(
42+
'input sample length must be less than or equal to the '
43+
f'dimension length, but the sample length of {size} '
44+
f'is greater than the dimension length of {dimsize} '
45+
f'for {dim}'
46+
)
3747
dim_slices.append(_slices(dimsize, size, olap))
3848

3949
for slices in itertools.product(*dim_slices):

xbatcher/tests/test_generators.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,16 @@ def test_preload_batch_true(sample_ds_1d):
211211
for ds_batch in bg:
212212
assert isinstance(ds_batch, xr.Dataset)
213213
assert not ds_batch.chunks
214+
215+
216+
def test_batch_exceptions(sample_ds_1d):
217+
# ValueError when input_dim[dim] > ds.sizes[dim]
218+
with pytest.raises(ValueError) as e:
219+
BatchGenerator(sample_ds_1d, input_dims={'x': 110})
220+
assert len(e) == 1
221+
# ValueError when input_overlap[dim] > input_dim[dim]
222+
with pytest.raises(ValueError) as e:
223+
BatchGenerator(
224+
sample_ds_1d, input_dims={'x': 10}, input_overlap={'x': 20}
225+
)
226+
assert len(e) == 1

0 commit comments

Comments
 (0)