Skip to content

Commit d3e1c2f

Browse files
author
Joe Hamman
authored
Merge pull request #66 from meghanrjones/3dtest-update
Test against expected batch in test_batch_3d_2d_input
2 parents 05b978c + 1f32348 commit d3e1c2f

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

xbatcher/tests/test_generators.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
133133

134134
@pytest.mark.parametrize('bsize', [5, 10])
135135
def test_batch_3d_1d_input(sample_ds_3d, bsize):
136-
137136
# first do the iteration over just one dimension
138137
bg = BatchGenerator(sample_ds_3d, input_dims={'x': bsize})
139138
for n, ds_batch in enumerate(bg):
@@ -164,8 +163,19 @@ def test_batch_3d_2d_input(sample_ds_3d, bsize):
164163
assert isinstance(ds_batch, xr.Dataset)
165164
assert ds_batch.dims['x'] == xbsize
166165
assert ds_batch.dims['y'] == bsize
167-
# TODO? Is it worth it to try to reproduce the internal logic of the
168-
# generator and verify that the slices are correct?
166+
yn, xn = np.unravel_index(
167+
n,
168+
(
169+
(sample_ds_3d.dims['y'] // bsize),
170+
(sample_ds_3d.dims['x'] // xbsize),
171+
),
172+
)
173+
expected_xslice = slice(xbsize * xn, xbsize * (xn + 1))
174+
expected_yslice = slice(bsize * yn, bsize * (yn + 1))
175+
ds_batch_expected = sample_ds_3d.isel(
176+
x=expected_xslice, y=expected_yslice
177+
)
178+
xr.testing.assert_equal(ds_batch_expected, ds_batch)
169179
assert (n + 1) == (
170180
(sample_ds_3d.dims['x'] // xbsize) * (sample_ds_3d.dims['y'] // bsize)
171181
)

0 commit comments

Comments
 (0)