@@ -178,31 +178,53 @@ def test_groupby_reduce(
178
178
assert_equal (expected_result , result )
179
179
180
180
181
- def gen_array_by (size , func ):
182
- by = np .ones (size [- 1 ])
183
- rng = np .random .default_rng (12345 )
181
+ def maybe_skip_cupy (array_module , func , engine ):
182
+ if array_module is np :
183
+ return
184
+
185
+ import cupy
186
+
187
+ assert array_module is cupy
188
+
189
+ if engine == "numba" :
190
+ pytest .skip ()
191
+
192
+ if engine == "numpy" and ("prod" in func or "first" in func or "last" in func ):
193
+ pytest .xfail ()
194
+ elif engine == "flox" and not (
195
+ "sum" in func or "mean" in func or "std" in func or "var" in func
196
+ ):
197
+ pytest .xfail ()
198
+
199
+
200
+ def gen_array_by (size , func , array_module ):
201
+ xp = array_module
202
+ by = xp .ones (size [- 1 ])
203
+ rng = xp .random .default_rng (12345 )
184
204
array = rng .random (size )
185
205
if "nan" in func and "nanarg" not in func :
186
- array [[1 , 4 , 5 ], ...] = np .nan
206
+ array [[1 , 4 , 5 ], ...] = xp .nan
187
207
elif "nanarg" in func and len (size ) > 1 :
188
- array [[1 , 4 , 5 ], 1 ] = np .nan
208
+ array [[1 , 4 , 5 ], 1 ] = xp .nan
189
209
if func in ["any" , "all" ]:
190
210
array = array > 0.5
191
211
return array , by
192
212
193
213
194
- @pytest .mark .parametrize ("chunks" , [None , - 1 , 3 , 4 ])
195
214
@pytest .mark .parametrize ("nby" , [1 , 2 , 3 ])
196
215
@pytest .mark .parametrize ("size" , ((12 ,), (12 , 9 )))
197
- @pytest .mark .parametrize ("add_nan_by " , [True , False ])
216
+ @pytest .mark .parametrize ("chunks " , [None , - 1 , 3 , 4 ])
198
217
@pytest .mark .parametrize ("func" , ALL_FUNCS )
199
- def test_groupby_reduce_all (nby , size , chunks , func , add_nan_by , engine ):
218
+ @pytest .mark .parametrize ("add_nan_by" , [True , False ])
219
+ def test_groupby_reduce_all (nby , size , chunks , func , add_nan_by , engine , array_module ):
200
220
if chunks is not None and not has_dask :
201
221
pytest .skip ()
202
222
if "arg" in func and engine == "flox" :
203
223
pytest .skip ()
204
224
205
- array , by = gen_array_by (size , func )
225
+ maybe_skip_cupy (array_module , func , engine )
226
+
227
+ array , by = gen_array_by (size , func , array_module )
206
228
if chunks :
207
229
array = dask .array .from_array (array , chunks = chunks )
208
230
by = (by ,) * nby
0 commit comments