Skip to content

Commit da42a9c

Browse files
author
cmdupuis3
committed
More squeeze_batch_dim tests; fix bug
1 parent fb29cba commit da42a9c

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

xbatcher/generators.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,18 @@ def _maybe_stack_batch_dims(
5858
ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'
5959
):
6060
batch_dims = [d for d in ds.dims if d not in input_dims]
61-
if len(batch_dims) < 2:
61+
if len(batch_dims) == 0:
6262
if squeeze_batch_dim:
6363
return ds
6464
else:
6565
return ds.expand_dims(stacked_dim_name, 0)
66-
ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
67-
# ensure correct order
68-
dim_order = (stacked_dim_name,) + tuple(input_dims)
69-
return ds_stack.transpose(*dim_order)
66+
elif len(batch_dims) == 1:
67+
return ds
68+
else:
69+
ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
70+
# ensure correct order
71+
dim_order = (stacked_dim_name,) + tuple(input_dims)
72+
return ds_stack.transpose(*dim_order)
7073

7174

7275
class BatchGenerator:

xbatcher/tests/test_generators.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,46 @@ def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
189189
assert list(ds_batch['foo'].shape) == [xbsize]
190190

191191

192+
@pytest.mark.parametrize('bsize', [5, 10])
193+
def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
194+
xbsize = 20
195+
bg = BatchGenerator(
196+
sample_ds_3d,
197+
input_dims={'y': bsize, 'x': xbsize},
198+
squeeze_batch_dim=False,
199+
)
200+
for ds_batch in bg:
201+
assert list(ds_batch['foo'].shape) == [10, bsize, xbsize]
202+
203+
bg2 = BatchGenerator(
204+
sample_ds_3d,
205+
input_dims={'y': bsize, 'x': xbsize},
206+
squeeze_batch_dim=True,
207+
)
208+
for ds_batch in bg2:
209+
assert list(ds_batch['foo'].shape) == [10, bsize, xbsize]
210+
211+
212+
@pytest.mark.parametrize('bsize', [5, 10])
213+
def test_batch_3d_squeeze_batch_dim2(sample_ds_3d, bsize):
214+
xbsize = 20
215+
bg = BatchGenerator(
216+
sample_ds_3d,
217+
input_dims={'x': xbsize},
218+
squeeze_batch_dim=False,
219+
)
220+
for ds_batch in bg:
221+
assert list(ds_batch['foo'].shape) == [500, xbsize]
222+
223+
bg2 = BatchGenerator(
224+
sample_ds_3d,
225+
input_dims={'x': xbsize},
226+
squeeze_batch_dim=True,
227+
)
228+
for ds_batch in bg2:
229+
assert list(ds_batch['foo'].shape) == [500, xbsize]
230+
231+
192232
def test_preload_batch_false(sample_ds_1d):
193233
sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2})
194234
bg = BatchGenerator(

0 commit comments

Comments
 (0)