Skip to content

Commit 44f3851

Browse files
committed
Add duck array support
1 parent e405517 commit 44f3851

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

flox/aggregate_flox.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
1313
# assumes input is sorted, which I do in core._prepare_for_flox
1414
aux = group_idx
1515

16-
flag = np.concatenate(([True], aux[1:] != aux[:-1]))
16+
flag = np.concatenate((np.array([True], like=array), aux[1:] != aux[:-1]))
1717
uniques = aux[flag]
1818
(inv_idx,) = flag.nonzero()
1919

@@ -25,7 +25,7 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
2525
if out is None:
2626
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
2727

28-
if (len(uniques) == size) and (uniques == np.arange(size)).all():
28+
if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
2929
# The previous version of this if condition
3030
# ((uniques[1:] - uniques[:-1]) == 1).all():
3131
# does not work when group_idx is [1, 2] for e.g.

flox/aggregations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def generic_aggregate(
4646
f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
4747
)
4848

49+
group_idx = np.asarray(group_idx, like=array)
50+
4951
return method(
5052
group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
5153
)

flox/xrutils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:
9898

9999

100100
def isnull(data):
101-
data = np.asarray(data)
101+
if not is_duck_array(data):
102+
data = np.asarray(data)
102103
scalar_type = data.dtype.type
103104
if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
104105
# datetime types use NaT for null

0 commit comments

Comments
 (0)