@@ -133,7 +133,6 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
133
133
134
134
@pytest .mark .parametrize ('bsize' , [5 , 10 ])
135
135
def test_batch_3d_1d_input (sample_ds_3d , bsize ):
136
-
137
136
# first do the iteration over just one dimension
138
137
bg = BatchGenerator (sample_ds_3d , input_dims = {'x' : bsize })
139
138
for n , ds_batch in enumerate (bg ):
@@ -164,8 +163,19 @@ def test_batch_3d_2d_input(sample_ds_3d, bsize):
164
163
assert isinstance (ds_batch , xr .Dataset )
165
164
assert ds_batch .dims ['x' ] == xbsize
166
165
assert ds_batch .dims ['y' ] == bsize
167
- # TODO? Is it worth it to try to reproduce the internal logic of the
168
- # generator and verify that the slices are correct?
166
+ yn , xn = np .unravel_index (
167
+ n ,
168
+ (
169
+ (sample_ds_3d .dims ['y' ] // bsize ),
170
+ (sample_ds_3d .dims ['x' ] // xbsize ),
171
+ ),
172
+ )
173
+ expected_xslice = slice (xbsize * xn , xbsize * (xn + 1 ))
174
+ expected_yslice = slice (bsize * yn , bsize * (yn + 1 ))
175
+ ds_batch_expected = sample_ds_3d .isel (
176
+ x = expected_xslice , y = expected_yslice
177
+ )
178
+ xr .testing .assert_equal (ds_batch_expected , ds_batch )
169
179
assert (n + 1 ) == (
170
180
(sample_ds_3d .dims ['x' ] // xbsize ) * (sample_ds_3d .dims ['y' ] // bsize )
171
181
)
0 commit comments