@@ -26,6 +26,40 @@ def test_batch_1d(sample_ds_1d, bsize):
26
26
ds_batch_expected = sample_ds_1d .isel (x = expected_slice )
27
27
assert ds_batch .equals (ds_batch_expected )
28
28
29
+ @pytest .mark .parametrize ("bsize" , [5 , 10 ])
30
+ def test_batch_1d_concat (sample_ds_1d , bsize ):
31
+ bg = BatchGenerator (sample_ds_1d , input_dims = {'x' : bsize },
32
+ concat_input_dims = True )
33
+ for n , ds_batch in enumerate (bg ):
34
+ assert isinstance (ds_batch , xr .Dataset )
35
+ assert ds_batch .dims ['x_input' ] == bsize
36
+ assert ds_batch .dims ['input_batch' ] == sample_ds_1d .dims ['x' ]// bsize
37
+ assert 'x' in ds_batch .coords
38
+
39
+ @pytest .mark .parametrize ("bsize" , [5 , 10 ])
40
+ def test_batch_1d_no_coordinate (sample_ds_1d , bsize ):
41
+ # fix for #3
42
+ ds_dropped = sample_ds_1d .drop ('x' )
43
+ bg = BatchGenerator (ds_dropped , input_dims = {'x' : bsize })
44
+ for n , ds_batch in enumerate (bg ):
45
+ assert isinstance (ds_batch , xr .Dataset )
46
+ assert ds_batch .dims ['x' ] == bsize
47
+ expected_slice = slice (bsize * n , bsize * (n + 1 ))
48
+ ds_batch_expected = ds_dropped .isel (x = expected_slice )
49
+ assert ds_batch .equals (ds_batch_expected )
50
+
51
+ @pytest .mark .parametrize ("bsize" , [5 , 10 ])
52
+ def test_batch_1d_concat_no_coordinate (sample_ds_1d , bsize ):
53
+ # fix for #3
54
+ ds_dropped = sample_ds_1d .drop ('x' )
55
+ bg = BatchGenerator (ds_dropped , input_dims = {'x' : bsize },
56
+ concat_input_dims = True )
57
+ for n , ds_batch in enumerate (bg ):
58
+ assert isinstance (ds_batch , xr .Dataset )
59
+ assert ds_batch .dims ['x_input' ] == bsize
60
+ assert ds_batch .dims ['input_batch' ] == sample_ds_1d .dims ['x' ]// bsize
61
+ assert 'x' not in ds_batch .coords
62
+
29
63
30
64
@pytest .mark .parametrize ("olap" , [1 , 4 ])
31
65
def test_batch_1d_overlap (sample_ds_1d , olap ):
0 commit comments