Skip to content

Commit 7882591

Browse files
committed
Allow method="cohorts" when grouping by dask array
This allows us to run the `bitmask` calculation remotely and send that back.
1 parent 12cbef9 commit 7882591

File tree

1 file changed

+49
-31
lines changed

1 file changed

+49
-31
lines changed

flox/core.py

Lines changed: 49 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -214,34 +214,8 @@ def slices_from_chunks(chunks):
214214
return product(*slices)
215215

216216

217-
@memoize
218-
def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
219-
"""
220-
Finds group labels that occur together aka "cohorts"
221-
222-
If available, results are cached in a 1MB cache managed by `cachey`.
223-
This allows us to be quick when repeatedly calling groupby_reduce
224-
for arrays with the same chunking (e.g. an xarray Dataset).
225-
226-
Parameters
227-
----------
228-
labels : np.ndarray
229-
mD Array of integer group codes, factorized so that -1
230-
represents NaNs.
231-
chunks : tuple
232-
chunks of the array being reduced
233-
merge : bool, optional
234-
Attempt to merge cohorts when one cohort's chunks are a subset
235-
of another cohort's chunks.
236-
237-
Returns
238-
-------
239-
cohorts: dict_values
240-
Iterable of cohorts
241-
"""
242-
# To do this, we must have values in memory so casting to numpy should be safe
243-
labels = np.asarray(labels)
244-
217+
def _compute_label_chunk_bitmask(labels, chunks):
218+
assert isinstance(labels, np.ndarray)
245219
shape = tuple(sum(c) for c in chunks)
246220
nchunks = math.prod(len(c) for c in chunks)
247221

@@ -271,6 +245,47 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
271245
cols_array = np.concatenate(cols)
272246
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
273247
bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
248+
249+
return bitmask, nlabels, ilabels
250+
251+
252+
@memoize
253+
def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
254+
"""
255+
Finds group labels that occur together aka "cohorts"
256+
257+
If available, results are cached in a 1MB cache managed by `cachey`.
258+
This allows us to be quick when repeatedly calling groupby_reduce
259+
for arrays with the same chunking (e.g. an xarray Dataset).
260+
261+
Parameters
262+
----------
263+
labels : np.ndarray
264+
mD Array of integer group codes, factorized so that -1
265+
represents NaNs.
266+
chunks : tuple
267+
chunks of the array being reduced
268+
merge : bool, optional
269+
Attempt to merge cohorts when one cohort's chunks are a subset
270+
of another cohort's chunks.
271+
272+
Returns
273+
-------
274+
cohorts: dict_values
275+
Iterable of cohorts
276+
"""
277+
if not is_duck_array(labels):
278+
labels = np.asarray(labels)
279+
280+
if is_duck_dask_array(labels):
281+
import dask
282+
283+
((bitmask, nlabels, ilabels),) = dask.compute(
284+
dask.delayed(_compute_label_chunk_bitmask)(labels, chunks)
285+
)
286+
else:
287+
bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks)
288+
274289
label_chunks = {
275290
lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
276291
for lab in range(nlabels)
@@ -2039,9 +2054,6 @@ def groupby_reduce(
20392054
"Try engine='numpy' or engine='numba' instead."
20402055
)
20412056

2042-
if method == "cohorts" and any_by_dask:
2043-
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
2044-
20452057
reindex = _validate_reindex(
20462058
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
20472059
)
@@ -2076,6 +2088,12 @@ def groupby_reduce(
20762088
# can't do it if we are grouping by dask array but don't have expected_groups
20772089
any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
20782090
)
2091+
2092+
if method == "cohorts" and not factorize_early:
2093+
raise ValueError(
2094+
"method='cohorts' can only be used when grouping by dask arrays if `expected_groups` is provided."
2095+
)
2096+
20792097
if factorize_early:
20802098
bys, final_groups, grp_shape = _factorize_multiple(
20812099
bys,

0 commit comments

Comments
 (0)