6
6
7
7
@pytest .fixture (scope = 'module' )
8
8
def sample_ds_1d ():
9
- size = 100
9
+ size = 100
10
10
ds = xr .Dataset ({'foo' : (['x' ], np .random .rand (size )),
11
11
'bar' : (['x' ], np .random .randint (0 , 10 , size ))},
12
12
{'x' : (['x' ], np .arange (size ))})
13
13
return ds
14
14
15
+
15
16
# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
16
17
# Should we enforce that each batch size always has to be the same
17
18
@pytest .mark .parametrize ("bsize" , [5 , 10 ])
@@ -25,6 +26,7 @@ def test_batch_1d(sample_ds_1d, bsize):
25
26
ds_batch_expected = sample_ds_1d .isel (x = expected_slice )
26
27
assert ds_batch .equals (ds_batch_expected )
27
28
29
+
28
30
@pytest .mark .parametrize ("olap" , [1 , 4 ])
29
31
def test_batch_1d_overlap (sample_ds_1d , olap ):
30
32
bsize = 10
@@ -37,3 +39,38 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
37
39
expected_slice = slice (stride * n , stride * n + bsize )
38
40
ds_batch_expected = sample_ds_1d .isel (x = expected_slice )
39
41
assert ds_batch .equals (ds_batch_expected )
42
+
43
+
44
+ @pytest .fixture (scope = 'module' )
45
+ def sample_ds_2d ():
46
+ shape = (50 , 100 )
47
+ ds = xr .Dataset ({'foo' : (['y' , 'x' ], np .random .rand (* shape )),
48
+ 'bar' : (['y' , 'x' ], np .random .randint (0 , 10 , shape ))},
49
+ {'x' : (['x' ], np .arange (shape [- 1 ])),
50
+ 'y' : (['y' ], np .arange (shape [- 2 ]))})
51
+ return ds
52
+
53
+
54
+ @pytest .mark .parametrize ("bsize" , [5 , 10 ])
55
+ def test_batch_2d (sample_ds_2d , bsize ):
56
+
57
+ # first do the iteration over just one dimension
58
+ bg = BatchGenerator (sample_ds_2d , batch_sizes = {'x' : bsize })
59
+ for n , ds_batch in enumerate (bg ):
60
+ assert isinstance (ds_batch , xr .Dataset )
61
+ assert ds_batch .dims ['x' ] == bsize
62
+ assert ds_batch .dims ['y' ] == sample_ds_2d .dims ['y' ]
63
+ expected_slice = slice (bsize * n , bsize * (n + 1 ))
64
+ ds_batch_expected = sample_ds_2d .isel (x = expected_slice )
65
+ assert ds_batch .equals (ds_batch_expected )
66
+
67
+ # now iterate over both
68
+ xbsize = 20
69
+ bg = BatchGenerator (sample_ds_2d , batch_sizes = {'y' : bsize , 'x' : xbsize })
70
+ for n , ds_batch in enumerate (bg ):
71
+ assert isinstance (ds_batch , xr .Dataset )
72
+ assert ds_batch .dims ['x' ] == xbsize
73
+ assert ds_batch .dims ['y' ] == bsize
74
+ # TODO? Is it worth it to try to reproduce the internal logic of the
75
+ # generator and verify that the slices are correct?
76
+ assert (n + 1 )== ((sample_ds_2d .dims ['x' ]// xbsize ) * (sample_ds_2d .dims ['y' ]// bsize ))
0 commit comments