4
4
5
5
from .xrutils import isnull
6
6
7
# cupy is an optional, GPU-only dependency. Importing it unconditionally would
# make this whole module fail to import on machines without CuPy installed,
# even when only numpy arrays are ever processed.
try:
    import cupy as cp
    import cupyx
except ImportError:
    cp = None
    cupyx = None

if cp is not None:
    # Array types that are routed to the cupyx scatter implementations.
    cupy_array_type = (cp.ndarray,)

    # Maps the numpy ufunc requested by the caller to the cupyx scatter
    # function that is called as ``op(out, group_idx, array)``.
    cupy_ops = {
        np.add: cupyx.scatter_add,
        np.minimum: cupyx.scatter_min,
        np.maximum: cupyx.scatter_max,
    }
else:
    # ``isinstance(x, ())`` is always False, so an ``isinstance`` check against
    # ``cupy_array_type`` never selects the cupy branch and the (empty)
    # dispatch table is never consulted.
    cupy_array_type = ()
    cupy_ops = {}
17
+
18
+
7
19
def _prepare_for_flox (group_idx , array ):
8
20
"""
9
21
Sort the input array once to save time.
@@ -24,7 +36,9 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
24
36
most of this code is from shoyer's gist
25
37
https://gist.github.com/shoyer/f538ac78ae904c936844
26
38
"""
27
- # assumes input is sorted, which I do in core._prepare_for_flox
39
+ # For numpy arrays, assumes input is sorted, which I do in _prepare_for_flox
40
+ # For cupy arrays, sorting is not needed
41
+
28
42
aux = group_idx
29
43
30
44
flag = np .concatenate ((np .array ([True ], like = array ), aux [1 :] != aux [:- 1 ]))
@@ -37,7 +51,12 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
37
51
dtype = array .dtype
38
52
39
53
if out is None :
40
- out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
54
+ out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype , like = array )
55
+
56
+ if isinstance (array , cupy_array_type ):
57
+ op = cupy_ops [op ]
58
+ op (out , group_idx , array )
59
+ return out
41
60
42
61
if (len (uniques ) == size ) and (uniques == np .arange (size , like = array )).all ():
43
62
# The previous version of this if condition
0 commit comments