Skip to content

Commit 5fa31f3

Browse files
committed
Allow method="cohorts" when grouping by dask array
This allows the `bitmask` calculation to run remotely as a delayed dask task, so that only the computed result is sent back.
1 parent f8f34b9 commit 5fa31f3

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

flox/core.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -329,7 +329,7 @@ def chunk_unique(labels, slicer, nlabels, label_is_present=None):
329329
rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
330330
cols_array = np.concatenate(cols)
331331

332-
return make_bitmask(rows_array, cols_array)
332+
return make_bitmask(rows_array, cols_array), nlabels, ilabels
333333

334334

335335
# @memoize
@@ -362,8 +362,17 @@ def find_group_cohorts(
362362
cohorts: dict_values
363363
Iterable of cohorts
364364
"""
365-
# To do this, we must have values in memory so casting to numpy should be safe
366-
labels = np.asarray(labels)
365+
if not is_duck_array(labels):
366+
labels = np.asarray(labels)
367+
368+
if is_duck_dask_array(labels):
369+
import dask
370+
371+
((bitmask, nlabels, ilabels),) = dask.compute(
372+
dask.delayed(_compute_label_chunk_bitmask)(labels, chunks)
373+
)
374+
else:
375+
bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks)
367376

368377
shape = tuple(sum(c) for c in chunks)
369378
nchunks = math.prod(len(c) for c in chunks)
@@ -2409,9 +2418,6 @@ def groupby_reduce(
24092418
"Try engine='numpy' or engine='numba' instead."
24102419
)
24112420

2412-
if method == "cohorts" and any_by_dask:
2413-
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
2414-
24152421
reindex = _validate_reindex(
24162422
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
24172423
)
@@ -2443,6 +2449,12 @@ def groupby_reduce(
24432449
# can't do it if we are grouping by dask array but don't have expected_groups
24442450
any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
24452451
)
2452+
2453+
if method == "cohorts" and not factorize_early:
2454+
raise ValueError(
2455+
"method='cohorts' can only be used when grouping by dask arrays if `expected_groups` is provided."
2456+
)
2457+
24462458
expected_: pd.RangeIndex | None
24472459
if factorize_early:
24482460
bys, final_groups, grp_shape = _factorize_multiple(

0 commit comments

Comments
 (0)