@@ -42,35 +42,57 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
42
42
43
43
44
44
@pytest .fixture (scope = 'module' )
45
- def sample_ds_2d ():
46
- shape = (50 , 100 )
47
- ds = xr .Dataset ({'foo' : (['y' , 'x' ], np .random .rand (* shape )),
48
- 'bar' : (['y' , 'x' ], np .random .randint (0 , 10 , shape ))},
45
+ def sample_ds_3d ():
46
+ shape = (10 , 50 , 100 )
47
+ ds = xr .Dataset ({'foo' : (['time' , ' y' , 'x' ], np .random .rand (* shape )),
48
+ 'bar' : (['time' , ' y' , 'x' ], np .random .randint (0 , 10 , shape ))},
49
49
{'x' : (['x' ], np .arange (shape [- 1 ])),
50
50
'y' : (['y' ], np .arange (shape [- 2 ]))})
51
51
return ds
52
52
53
53
54
54
@pytest .mark .parametrize ("bsize" , [5 , 10 ])
55
- def test_batch_2d ( sample_ds_2d , bsize ):
55
+ def test_batch_3d_1d_input ( sample_ds_3d , bsize ):
56
56
57
57
# first do the iteration over just one dimension
58
- bg = BatchGenerator (sample_ds_2d , input_dims = {'x' : bsize })
58
+ bg = BatchGenerator (sample_ds_3d , input_dims = {'x' : bsize })
59
59
for n , ds_batch in enumerate (bg ):
60
60
assert isinstance (ds_batch , xr .Dataset )
61
61
assert ds_batch .dims ['x' ] == bsize
62
- assert ds_batch .dims ['y' ] == sample_ds_2d .dims ['y' ]
62
+ # time and y should be collapsed into batch dimension
63
+ assert ds_batch .dims ['batch' ] == sample_ds_3d .dims ['y' ] * sample_ds_3d .dims ['time' ]
63
64
expected_slice = slice (bsize * n , bsize * (n + 1 ))
64
- ds_batch_expected = sample_ds_2d .isel (x = expected_slice )
65
+ ds_batch_expected = (sample_ds_3d .isel (x = expected_slice )
66
+ .stack (batch = ['y' , 'time' ])
67
+ .transpose ('batch' , 'x' ))
68
+ print (ds_batch )
69
+ print (ds_batch_expected )
65
70
assert ds_batch .equals (ds_batch_expected )
66
71
72
+ @pytest .mark .parametrize ("bsize" , [5 , 10 ])
73
+ def test_batch_3d_2d_input (sample_ds_3d , bsize ):
67
74
# now iterate over both
68
75
xbsize = 20
69
- bg = BatchGenerator (sample_ds_2d , input_dims = {'y' : bsize , 'x' : xbsize })
76
+ bg = BatchGenerator (sample_ds_3d , input_dims = {'y' : bsize , 'x' : xbsize })
70
77
for n , ds_batch in enumerate (bg ):
71
78
assert isinstance (ds_batch , xr .Dataset )
72
79
assert ds_batch .dims ['x' ] == xbsize
73
80
assert ds_batch .dims ['y' ] == bsize
74
81
# TODO? Is it worth it to try to reproduce the internal logic of the
75
82
# generator and verify that the slices are correct?
76
- assert (n + 1 )== ((sample_ds_2d .dims ['x' ]// xbsize ) * (sample_ds_2d .dims ['y' ]// bsize ))
83
+ assert (n + 1 )== ((sample_ds_3d .dims ['x' ]// xbsize ) * (sample_ds_3d .dims ['y' ]// bsize ))
84
+
85
+
86
+ @pytest .mark .parametrize ("bsize" , [5 , 10 ])
87
+ def test_batch_3d_2d_input_concat (sample_ds_3d , bsize ):
88
+ # now iterate over both
89
+ xbsize = 20
90
+ bg = BatchGenerator (sample_ds_3d , input_dims = {'y' : bsize , 'x' : xbsize },
91
+ concat_input_dims = True )
92
+ for n , ds_batch in enumerate (bg ):
93
+ assert isinstance (ds_batch , xr .Dataset )
94
+ assert ds_batch .dims ['x_input' ] == xbsize
95
+ assert ds_batch .dims ['y_input' ] == bsize
96
+ assert ds_batch .dims ['batch' ] == ((sample_ds_3d .dims ['x' ]// xbsize ) *
97
+ (sample_ds_3d .dims ['y' ]// bsize ) *
98
+ sample_ds_3d .dims ['time' ])
0 commit comments