Skip to content

Commit 0e8f716

Browse files
author
cmdupuis3
committed
squeeze_batch_dim test sketch
1 parent c61846d commit 0e8f716

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

xbatcher/tests/test_generators.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
160160
* (sample_ds_3d.dims['y'] // bsize)
161161
* sample_ds_3d.dims['time']
162162
)
163+
164+
165+
@pytest.mark.parametrize('bsize', [5, 10])
166+
def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
167+
xbsize = 20
168+
bg = BatchGenerator(
169+
sample_ds_3d,
170+
input_dims={'y': bsize, 'x': xbsize},
171+
squeeze_batch_dim=False,
172+
)
173+
for ds_batch in bg:
174+
assert ds_batch['x'].shape == [1, bsize, xbsize]
175+
176+
bg2 = BatchGenerator(
177+
sample_ds_3d,
178+
input_dims={'y': bsize, 'x': xbsize},
179+
squeeze_batch_dim=True,
180+
)
181+
for ds_batch in bg:
182+
assert ds_batch['x'].shape == [bsize, xbsize]

0 commit comments

Comments
 (0)