@@ -18,6 +18,29 @@ def sample_ds_1d():
18
18
return ds
19
19
20
20
21
+ @pytest .fixture (scope = 'module' )
22
+ def sample_ds_3d ():
23
+ shape = (10 , 50 , 100 )
24
+ ds = xr .Dataset (
25
+ {
26
+ 'foo' : (['time' , 'y' , 'x' ], np .random .rand (* shape )),
27
+ 'bar' : (['time' , 'y' , 'x' ], np .random .randint (0 , 10 , shape )),
28
+ },
29
+ {
30
+ 'x' : (['x' ], np .arange (shape [- 1 ])),
31
+ 'y' : (['y' ], np .arange (shape [- 2 ])),
32
+ },
33
+ )
34
+ return ds
35
+
36
+
37
+ def test_constructor_coerces_to_dataset ():
38
+ da = xr .DataArray (np .random .rand (10 ), dims = 'x' , name = 'foo' )
39
+ bg = BatchGenerator (da , input_dims = {'x' : 2 })
40
+ assert isinstance (bg .ds , xr .Dataset )
41
+ assert bg .ds .equals (da .to_dataset ())
42
+
43
+
21
44
# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
22
45
# Should we enforce that each batch size always has to be the same
23
46
@pytest .mark .parametrize ('bsize' , [5 , 10 ])
@@ -86,22 +109,6 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
86
109
assert ds_batch .equals (ds_batch_expected )
87
110
88
111
89
- @pytest .fixture (scope = 'module' )
90
- def sample_ds_3d ():
91
- shape = (10 , 50 , 100 )
92
- ds = xr .Dataset (
93
- {
94
- 'foo' : (['time' , 'y' , 'x' ], np .random .rand (* shape )),
95
- 'bar' : (['time' , 'y' , 'x' ], np .random .randint (0 , 10 , shape )),
96
- },
97
- {
98
- 'x' : (['x' ], np .arange (shape [- 1 ])),
99
- 'y' : (['y' ], np .arange (shape [- 2 ])),
100
- },
101
- )
102
- return ds
103
-
104
-
105
112
@pytest .mark .parametrize ('bsize' , [5 , 10 ])
106
113
def test_batch_3d_1d_input (sample_ds_3d , bsize ):
107
114
@@ -160,3 +167,22 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
160
167
* (sample_ds_3d .dims ['y' ] // bsize )
161
168
* sample_ds_3d .dims ['time' ]
162
169
)
170
+
171
+
172
+ def test_preload_batch_false (sample_ds_1d ):
173
+ sample_ds_1d_dask = sample_ds_1d .chunk ({'x' : 2 })
174
+ bg = BatchGenerator (
175
+ sample_ds_1d_dask , input_dims = {'x' : 2 }, preload_batch = False
176
+ )
177
+ assert bg .preload_batch is False
178
+ for ds_batch in bg :
179
+ assert isinstance (ds_batch , xr .Dataset )
180
+ assert ds_batch .chunks
181
+
182
+
183
+ def test_preload_batch_true (sample_ds_1d ):
184
+ bg = BatchGenerator (sample_ds_1d , input_dims = {'x' : 2 }, preload_batch = True )
185
+ assert bg .preload_batch is True
186
+ for ds_batch in bg :
187
+ assert isinstance (ds_batch , xr .Dataset )
188
+ assert not ds_batch .chunks
0 commit comments