@@ -329,7 +329,7 @@ def chunk_unique(labels, slicer, nlabels, label_is_present=None):
329
329
rows_array = np .repeat (np .arange (nchunks ), tuple (len (col ) for col in cols ))
330
330
cols_array = np .concatenate (cols )
331
331
332
- return make_bitmask (rows_array , cols_array )
332
+ return make_bitmask (rows_array , cols_array ), nlabels , ilabels
333
333
334
334
335
335
# @memoize
@@ -362,8 +362,17 @@ def find_group_cohorts(
362
362
cohorts: dict_values
363
363
Iterable of cohorts
364
364
"""
365
- # To do this, we must have values in memory so casting to numpy should be safe
366
- labels = np .asarray (labels )
365
+ if not is_duck_array (labels ):
366
+ labels = np .asarray (labels )
367
+
368
+ if is_duck_dask_array (labels ):
369
+ import dask
370
+
371
+ ((bitmask , nlabels , ilabels ),) = dask .compute (
372
+ dask .delayed (_compute_label_chunk_bitmask )(labels , chunks )
373
+ )
374
+ else :
375
+ bitmask , nlabels , ilabels = _compute_label_chunk_bitmask (labels , chunks )
367
376
368
377
shape = tuple (sum (c ) for c in chunks )
369
378
nchunks = math .prod (len (c ) for c in chunks )
@@ -2409,9 +2418,6 @@ def groupby_reduce(
2409
2418
"Try engine='numpy' or engine='numba' instead."
2410
2419
)
2411
2420
2412
- if method == "cohorts" and any_by_dask :
2413
- raise ValueError (f"method={ method !r} can only be used when grouping by numpy arrays." )
2414
-
2415
2421
reindex = _validate_reindex (
2416
2422
reindex , func , method , expected_groups , any_by_dask , is_duck_dask_array (array )
2417
2423
)
@@ -2443,6 +2449,12 @@ def groupby_reduce(
2443
2449
# can't do it if we are grouping by dask array but don't have expected_groups
2444
2450
any (is_dask and ex_ is None for is_dask , ex_ in zip (by_is_dask , expected_groups ))
2445
2451
)
2452
+
2453
+ if method == "cohorts" and not factorize_early :
2454
+ raise ValueError (
2455
+ "method='cohorts' can only be used when grouping by dask arrays if `expected_groups` is provided."
2456
+ )
2457
+
2446
2458
expected_ : pd .RangeIndex | None
2447
2459
if factorize_early :
2448
2460
bys , final_groups , grp_shape = _factorize_multiple (
0 commit comments