dask.config.set(scheduler="sync")

-try:
-    # test against legacy xarray implementation
-    xr.set_options(use_flox=False)
-except ValueError:
-    pass
-
-
+# test against legacy xarray implementation
+# avoid some compilation overhead
+xr.set_options(use_flox=False, use_numbagg=False)

tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
np.random.seed(123)

@pytest.mark.parametrize("min_count", [None, 1, 3])
@pytest.mark.parametrize("add_nan", [True, False])
@pytest.mark.parametrize("skipna", [True, False])
-def test_xarray_reduce(skipna, add_nan, min_count, engine, reindex):
+def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex):
+    engine = engine_no_numba
    if skipna is False and min_count is not None:
        pytest.skip()
@@ -57,7 +54,13 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine, reindex):
    expected = da.groupby("labels").sum(skipna=skipna, min_count=min_count)
    actual = xarray_reduce(
-        da, "labels", func="sum", skipna=skipna, min_count=min_count, engine=engine, reindex=reindex
+        da,
+        "labels",
+        func="sum",
+        skipna=skipna,
+        min_count=min_count,
+        engine=engine,
+        reindex=reindex,
    )
    assert_equal(expected, actual)
@@ -85,9 +88,10 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine, reindex):
# TODO: sort
@pytest.mark.parametrize("pass_expected_groups", [True, False])
@pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
-def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine):
+def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_numba):
    if chunk and pass_expected_groups is False:
        pytest.skip()
+    engine = engine_no_numba

    arr = np.ones((4, 12))
    labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
@@ -131,9 +135,10 @@ def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine):
@pytest.mark.parametrize("pass_expected_groups", [True, False])
@pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
-def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):
+def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine_no_numba):
    if chunk and pass_expected_groups is False:
        pytest.skip()
+    engine = engine_no_numba

    arr = np.ones((2, 12))
    labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
@@ -187,7 +192,8 @@ def test_validate_expected_groups(expected_groups):
@requires_cftime
@requires_dask
-def test_xarray_reduce_single_grouper(engine):
+def test_xarray_reduce_single_grouper(engine_no_numba):
+    engine = engine_no_numba
    # DataArray
    ds = xr.Dataset(
        {"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))},
@@ -293,15 +299,17 @@ def test_rechunk_for_blockwise(inchunks, expected):
# TODO: dim=None, dim=Ellipsis, groupby unindexed dim


-def test_groupby_duplicate_coordinate_labels(engine):
+def test_groupby_duplicate_coordinate_labels(engine_no_numba):
+    engine = engine_no_numba
    # fix for http://stackoverflow.com/questions/38065129
    array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
    expected = xr.DataArray([3, 3], [("x", [1, 2])])
    actual = xarray_reduce(array, array.x, func="sum", engine=engine)
    assert_equal(expected, actual)


-def test_multi_index_groupby_sum(engine):
+def test_multi_index_groupby_sum(engine_no_numba):
+    engine = engine_no_numba
    # regression test for xarray GH873
    ds = xr.Dataset(
        {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))},
@@ -327,7 +335,8 @@ def test_multi_index_groupby_sum(engine):
@pytest.mark.parametrize("chunks", (None, pytest.param(2, marks=requires_dask)))
-def test_xarray_groupby_bins(chunks, engine):
+def test_xarray_groupby_bins(chunks, engine_no_numba):
+    engine = engine_no_numba
    array = xr.DataArray([1, 1, 1, 1, 1], dims="x")
    labels = xr.DataArray([1, 1.5, 1.9, 2, 3], dims="x", name="labels")
@@ -495,11 +504,11 @@ def test_alignment_error():
@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
-def test_dtype(add_nan, chunk, dtype, dtype_out, engine):
-    if engine == "numbagg":
+def test_dtype(add_nan, chunk, dtype, dtype_out, engine_no_numba):
+    if engine_no_numba == "numbagg":
        # https://github.com/numbagg/numbagg/issues/121
        pytest.skip()
-
+    engine = engine_no_numba
    xp = dask.array if chunk else np
    data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))
@@ -707,7 +716,7 @@ def test_multiple_quantiles(q, chunk, by_ndim, skipna):
    da = xr.DataArray(array, dims=("x", *dims))
    by = xr.DataArray(labels, dims=dims, name="by")

-    actual = xarray_reduce(da, by, func="quantile", skipna=skipna, q=q)
+    actual = xarray_reduce(da, by, func="quantile", skipna=skipna, q=q, engine="flox")
    with xr.set_options(use_flox=False):
        expected = da.groupby(by).quantile(q, skipna=skipna)
    xr.testing.assert_allclose(expected, actual)
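These tests switch from the `engine` fixture to an `engine_no_numba` fixture, whose definition lives in the test suite's conftest.py and is not shown in this diff. Below is a minimal sketch of what such a fixture might look like; the parameter list ("flox", "numpy", "numbagg") and the conftest.py location are assumptions, not taken from this change.

# conftest.py -- hypothetical sketch, not part of this diff
import pytest


@pytest.fixture(params=["flox", "numpy", "numbagg"])
def engine_no_numba(request):
    # Same idea as the existing `engine` fixture but without "numba",
    # so these tests avoid its JIT compilation overhead.
    # Note that test_dtype above still special-cases the "numbagg" parameter.
    return request.param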