Skip to content

Commit 3e41939

Browse files
committed
added some new tests
1 parent 1e6c0f5 commit 3e41939

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

xbatcher/generators.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,13 @@ def __iter__(self):
109109
ds_batch.load()
110110
input_generator = self._iterate_input_dims(ds_batch)
111111
if self.concat_input_dims:
112-
all_dsets = [_drop_input_dims(ds_input, list(self.input_dims))
112+
new_dim_suffix = '_input'
113+
all_dsets = [_drop_input_dims(ds_input, list(self.input_dims),
114+
suffix=new_dim_suffix)
113115
for ds_input in input_generator]
114-
dsc = xr.concat(all_batches, dim='input_batch')
115-
yield _maybe_stack_batch_dims(dsc, list(self.input_dims))
116+
dsc = xr.concat(all_dsets, dim='input_batch')
117+
new_input_dims = [dim + new_dim_suffix for dim in self.input_dims]
118+
yield _maybe_stack_batch_dims(dsc, new_input_dims)
116119
else:
117120
for ds_input in input_generator:
118121
yield _maybe_stack_batch_dims(ds_input, list(self.input_dims))

xbatcher/tests/test_generators.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,57 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
4242

4343

4444
@pytest.fixture(scope='module')
45-
def sample_ds_2d():
46-
shape = (50, 100)
47-
ds = xr.Dataset({'foo': (['y', 'x'], np.random.rand(*shape)),
48-
'bar': (['y', 'x'], np.random.randint(0, 10, shape))},
45+
def sample_ds_3d():
46+
shape = (10, 50, 100)
47+
ds = xr.Dataset({'foo': (['time', 'y', 'x'], np.random.rand(*shape)),
48+
'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape))},
4949
{'x': (['x'], np.arange(shape[-1])),
5050
'y': (['y'], np.arange(shape[-2]))})
5151
return ds
5252

5353

5454
@pytest.mark.parametrize("bsize", [5, 10])
55-
def test_batch_2d(sample_ds_2d, bsize):
55+
def test_batch_3d_1d_input(sample_ds_3d, bsize):
5656

5757
# first do the iteration over just one dimension
58-
bg = BatchGenerator(sample_ds_2d, input_dims={'x': bsize})
58+
bg = BatchGenerator(sample_ds_3d, input_dims={'x': bsize})
5959
for n, ds_batch in enumerate(bg):
6060
assert isinstance(ds_batch, xr.Dataset)
6161
assert ds_batch.dims['x'] == bsize
62-
assert ds_batch.dims['y'] == sample_ds_2d.dims['y']
62+
# time and y should be collapsed into batch dimension
63+
assert ds_batch.dims['batch'] == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time']
6364
expected_slice = slice(bsize*n, bsize*(n+1))
64-
ds_batch_expected = sample_ds_2d.isel(x=expected_slice)
65+
ds_batch_expected = (sample_ds_3d.isel(x=expected_slice)
66+
.stack(batch=['y', 'time'])
67+
.transpose('batch', 'x'))
68+
print(ds_batch)
69+
print(ds_batch_expected)
6570
assert ds_batch.equals(ds_batch_expected)
6671

72+
@pytest.mark.parametrize("bsize", [5, 10])
73+
def test_batch_3d_2d_input(sample_ds_3d, bsize):
6774
# now iterate over both
6875
xbsize = 20
69-
bg = BatchGenerator(sample_ds_2d, input_dims={'y': bsize, 'x': xbsize})
76+
bg = BatchGenerator(sample_ds_3d, input_dims={'y': bsize, 'x': xbsize})
7077
for n, ds_batch in enumerate(bg):
7178
assert isinstance(ds_batch, xr.Dataset)
7279
assert ds_batch.dims['x'] == xbsize
7380
assert ds_batch.dims['y'] == bsize
7481
# TODO? Is it worth it to try to reproduce the internal logic of the
7582
# generator and verify that the slices are correct?
76-
assert (n+1)==((sample_ds_2d.dims['x']//xbsize) * (sample_ds_2d.dims['y']//bsize))
83+
assert (n+1)==((sample_ds_3d.dims['x']//xbsize) * (sample_ds_3d.dims['y']//bsize))
84+
85+
86+
@pytest.mark.parametrize("bsize", [5, 10])
87+
def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
88+
# now iterate over both
89+
xbsize = 20
90+
bg = BatchGenerator(sample_ds_3d, input_dims={'y': bsize, 'x': xbsize},
91+
concat_input_dims=True)
92+
for n, ds_batch in enumerate(bg):
93+
assert isinstance(ds_batch, xr.Dataset)
94+
assert ds_batch.dims['x_input'] == xbsize
95+
assert ds_batch.dims['y_input'] == bsize
96+
assert ds_batch.dims['batch'] == ((sample_ds_3d.dims['x']//xbsize) *
97+
(sample_ds_3d.dims['y']//bsize) *
98+
sample_ds_3d.dims['time'])

0 commit comments

Comments
 (0)