File tree Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -163,20 +163,20 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
163
163
164
164
165
165
@pytest .mark .parametrize ('bsize' , [5 , 10 ])
166
- def test_batch_3d_squeeze_batch_dim ( sample_ds_3d , bsize ):
166
+ def test_batch_1d_squeeze_batch_dim ( sample_ds_1d , bsize ):
167
167
xbsize = 20
168
168
bg = BatchGenerator (
169
- sample_ds_3d ,
170
- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
169
+ sample_ds_1d ,
170
+ input_dims = {'x' : xbsize },
171
171
squeeze_batch_dim = False ,
172
172
)
173
173
for ds_batch in bg :
174
- assert ds_batch ['x ' ].shape == [1 , bsize , xbsize ]
174
+ assert list ( ds_batch ['foo ' ].shape ) == [1 , xbsize ]
175
175
176
176
bg2 = BatchGenerator (
177
- sample_ds_3d ,
178
- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
177
+ sample_ds_1d ,
178
+ input_dims = {'x' : xbsize },
179
179
squeeze_batch_dim = True ,
180
180
)
181
- for ds_batch in bg :
182
- assert ds_batch ['x ' ].shape == [bsize , xbsize ]
181
+ for ds_batch in bg2 :
182
+ assert list ( ds_batch ['foo ' ].shape ) == [xbsize ]
You can’t perform that action at this time.
0 commit comments