Skip to content

Commit e5a60c0

Browse files
committed
WIP cupy support for engine="flox"
1 parent 2e6d385 commit e5a60c0

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

flox/aggregate_flox.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@
44

55
from .xrutils import isnull
66

7+
import cupy as cp
8+
import cupyx
9+
10+
cupy_array_type = (cp.ndarray,)
11+
12+
cupy_ops = {
13+
np.add: cupyx.scatter_add,
14+
np.minimum: cupyx.scatter_min,
15+
np.maximum: cupyx.scatter_max,
16+
}
17+
18+
719
def _prepare_for_flox(group_idx, array):
820
"""
921
Sort the input array once to save time.
@@ -24,7 +36,9 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
2436
most of this code is from shoyer's gist
2537
https://gist.github.com/shoyer/f538ac78ae904c936844
2638
"""
27-
# assumes input is sorted, which I do in core._prepare_for_flox
39+
# For numpy arrays, assumes input is sorted, which I do in _prepare_for_flox
40+
# For cupy arrays, sorting is not needed
41+
2842
aux = group_idx
2943

3044
flag = np.concatenate((np.array([True], like=array), aux[1:] != aux[:-1]))
@@ -37,7 +51,12 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
3751
dtype = array.dtype
3852

3953
if out is None:
40-
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
54+
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype, like=array)
55+
56+
if isinstance(array, cupy_array_type):
57+
op = cupy_ops[op]
58+
op(out, group_idx, array)
59+
return out
4160

4261
if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
4362
# The previous version of this if condition

0 commit comments

Comments
 (0)