@@ -189,6 +189,46 @@ def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
189
189
assert list (ds_batch ['foo' ].shape ) == [xbsize ]
190
190
191
191
192
+ @pytest .mark .parametrize ('bsize' , [5 , 10 ])
193
+ def test_batch_3d_squeeze_batch_dim (sample_ds_3d , bsize ):
194
+ xbsize = 20
195
+ bg = BatchGenerator (
196
+ sample_ds_3d ,
197
+ input_dims = {'y' : bsize , 'x' : xbsize },
198
+ squeeze_batch_dim = False ,
199
+ )
200
+ for ds_batch in bg :
201
+ assert list (ds_batch ['foo' ].shape ) == [10 , bsize , xbsize ]
202
+
203
+ bg2 = BatchGenerator (
204
+ sample_ds_3d ,
205
+ input_dims = {'y' : bsize , 'x' : xbsize },
206
+ squeeze_batch_dim = True ,
207
+ )
208
+ for ds_batch in bg2 :
209
+ assert list (ds_batch ['foo' ].shape ) == [10 , bsize , xbsize ]
210
+
211
+
212
+ @pytest .mark .parametrize ('bsize' , [5 , 10 ])
213
+ def test_batch_3d_squeeze_batch_dim2 (sample_ds_3d , bsize ):
214
+ xbsize = 20
215
+ bg = BatchGenerator (
216
+ sample_ds_3d ,
217
+ input_dims = {'x' : xbsize },
218
+ squeeze_batch_dim = False ,
219
+ )
220
+ for ds_batch in bg :
221
+ assert list (ds_batch ['foo' ].shape ) == [500 , xbsize ]
222
+
223
+ bg2 = BatchGenerator (
224
+ sample_ds_3d ,
225
+ input_dims = {'x' : xbsize },
226
+ squeeze_batch_dim = True ,
227
+ )
228
+ for ds_batch in bg2 :
229
+ assert list (ds_batch ['foo' ].shape ) == [500 , xbsize ]
230
+
231
+
192
232
def test_preload_batch_false (sample_ds_1d ):
193
233
sample_ds_1d_dask = sample_ds_1d .chunk ({'x' : 2 })
194
234
bg = BatchGenerator (
0 commit comments