@@ -214,34 +214,8 @@ def slices_from_chunks(chunks):
     return product(*slices)


-@memoize
-def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
-    """
-    Finds groups labels that occur together aka "cohorts"
-
-    If available, results are cached in a 1MB cache managed by `cachey`.
-    This allows us to be quick when repeatedly calling groupby_reduce
-    for arrays with the same chunking (e.g. an xarray Dataset).
-
-    Parameters
-    ----------
-    labels : np.ndarray
-        mD Array of integer group codes, factorized so that -1
-        represents NaNs.
-    chunks : tuple
-        chunks of the array being reduced
-    merge : bool, optional
-        Attempt to merge cohorts when one cohort's chunks are a subset
-        of another cohort's chunks.
-
-    Returns
-    -------
-    cohorts: dict_values
-        Iterable of cohorts
-    """
-    # To do this, we must have values in memory so casting to numpy should be safe
-    labels = np.asarray(labels)
-
+def _compute_label_chunk_bitmask(labels, chunks):
+    assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)

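For readers unfamiliar with dask-style `chunks`, here is a quick illustration (not part of the diff; the numbers are made up) of what the two context lines above compute:

    import math

    # A 2D array split into 3 chunks along axis 0 and 2 chunks along axis 1.
    chunks = ((3, 3, 2), (4, 4))
    shape = tuple(sum(c) for c in chunks)        # (8, 8): full array shape
    nchunks = math.prod(len(c) for c in chunks)  # 3 * 2 = 6 blocks in total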
@@ -271,6 +245,47 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     cols_array = np.concatenate(cols)
     data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
     bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
+
+    return bitmask, nlabels, ilabels
+
+
+@memoize
+def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
+    """
+    Finds groups labels that occur together aka "cohorts"
+
+    If available, results are cached in a 1MB cache managed by `cachey`.
+    This allows us to be quick when repeatedly calling groupby_reduce
+    for arrays with the same chunking (e.g. an xarray Dataset).
+
+    Parameters
+    ----------
+    labels : np.ndarray
+        mD Array of integer group codes, factorized so that -1
+        represents NaNs.
+    chunks : tuple
+        chunks of the array being reduced
+    merge : bool, optional
+        Attempt to merge cohorts when one cohort's chunks are a subset
+        of another cohort's chunks.
+
+    Returns
+    -------
+    cohorts: dict_values
+        Iterable of cohorts
+    """
+    if not is_duck_array(labels):
+        labels = np.asarray(labels)
+
+    if is_duck_dask_array(labels):
+        import dask
+
+        ((bitmask, nlabels, ilabels),) = dask.compute(
+            dask.delayed(_compute_label_chunk_bitmask)(labels, chunks)
+        )
+    else:
+        bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks)
+
     label_chunks = {
         lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
         for lab in range(nlabels)
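For context (not part of the diff): `bitmask` is a sparse chunks-by-labels matrix, and the new dask branch simply builds it inside `dask.delayed` and materializes it with a single `dask.compute` call. The CSC layout is what makes the `label_chunks` lookup above cheap, because each column's row indices are stored contiguously. A minimal sketch with made-up numbers, assuming scipy >= 1.8 for `csc_array`:

    import numpy as np
    from scipy.sparse import csc_array

    # Toy example: 4 chunks, 3 labels. Each (row, col) pair records
    # "label `col` occurs somewhere in chunk `row`".
    nchunks, nlabels = 4, 3
    rows_array = np.array([0, 0, 1, 2, 3, 3])
    cols_array = np.array([0, 1, 1, 2, 0, 2])
    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))

    # In CSC storage, column `lab` occupies indices[indptr[lab]:indptr[lab + 1]],
    # so the chunks containing each label fall out directly:
    label_chunks = {
        lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
        for lab in range(nlabels)
    }
    # -> {0: [0, 3], 1: [0, 1], 2: [2, 3]}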
@@ -2039,9 +2054,6 @@ def groupby_reduce(
             "Try engine='numpy' or engine='numba' instead."
         )

-    if method == "cohorts" and any_by_dask:
-        raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
-
     reindex = _validate_reindex(
         reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
     )
@@ -2076,6 +2088,12 @@ def groupby_reduce(
         # can't do it if we are grouping by dask array but don't have expected_groups
         any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
     )
+
+    if method == "cohorts" and not factorize_early:
+        raise ValueError(
+            "method='cohorts' can only be used when grouping by dask arrays if `expected_groups` is provided."
+        )
+
     if factorize_early:
         bys, final_groups, grp_shape = _factorize_multiple(
             bys,
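A rough usage sketch of what the relaxed check permits (not part of the diff; it assumes flox's public `groupby_reduce` signature and uses made-up data): with `expected_groups` supplied, grouping by a dask array can now take `method="cohorts"`, while omitting `expected_groups` leaves `factorize_early` False and trips the new ValueError instead.

    import dask.array as da
    import numpy as np
    from flox import groupby_reduce

    array = da.ones((12,), chunks=3)
    by = da.from_array(np.repeat(np.arange(4), 3), chunks=3)  # dask-backed group labels

    # expected_groups is what lets the labels be factorized early, so the
    # cohorts machinery can see which groups land in which chunks.
    result, groups = groupby_reduce(
        array,
        by,
        func="sum",
        expected_groups=np.arange(4),
        method="cohorts",
    )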