Skip to content

Commit 142031d

Browse files
author
cmdupuis3
committed
Fix 1D squeeze_batch_dim test
1 parent 749ac26 commit 142031d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

xbatcher/tests/test_generators.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,20 +163,20 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
163163

164164

165165
@pytest.mark.parametrize('bsize', [5, 10])
166-
def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
166+
def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
167167
xbsize = 20
168168
bg = BatchGenerator(
169-
sample_ds_3d,
170-
input_dims={'time': 1, 'y': bsize, 'x': xbsize},
169+
sample_ds_1d,
170+
input_dims={'x': xbsize},
171171
squeeze_batch_dim=False,
172172
)
173173
for ds_batch in bg:
174-
assert ds_batch['x'].shape == [1, bsize, xbsize]
174+
assert list(ds_batch['foo'].shape) == [1, xbsize]
175175

176176
bg2 = BatchGenerator(
177-
sample_ds_3d,
178-
input_dims={'time': 1, 'y': bsize, 'x': xbsize},
177+
sample_ds_1d,
178+
input_dims={'x': xbsize},
179179
squeeze_batch_dim=True,
180180
)
181-
for ds_batch in bg:
182-
assert ds_batch['x'].shape == [bsize, xbsize]
181+
for ds_batch in bg2:
182+
assert list(ds_batch['foo'].shape) == [xbsize]

0 commit comments

Comments
 (0)