@@ -248,12 +248,22 @@ def slices_from_chunks(chunks):
 
 
 def _compute_label_chunk_bitmask(labels, chunks, nlabels):
+    def make_bitmask(rows, cols):
+        data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
+        return csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))
+
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
 
-    labels = np.broadcast_to(labels, shape[-labels.ndim :])
+    # Shortcut for 1D with size-1 chunks
+    if shape == (nchunks,):
+        rows_array = np.arange(nchunks)
+        cols_array = labels
+        mask = labels >= 0
+        return make_bitmask(rows_array[mask], cols_array[mask])
 
+    labels = np.broadcast_to(labels, shape[-labels.ndim :])
     cols = []
     # Add one to handle the -1 sentinel value
     label_is_present = np.zeros((nlabels + 1,), dtype=bool)
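
Note: a minimal sketch (not part of the diff) of what the new make_bitmask helper builds, namely a sparse chunk-by-label matrix whose entry (i, j) is True when chunk i contains label j. The rows/cols values below are made up for illustration; only the COO-style csc_array constructor comes from the code above.

import numpy as np
from scipy.sparse import csc_array

nchunks, nlabels = 3, 4
# Hypothetical pairs: chunk 0 holds labels 0 and 1, chunk 1 holds label 1,
# chunk 2 holds labels 2 and 3.
rows = np.array([0, 0, 1, 2, 2])
cols = np.array([0, 1, 1, 2, 3])

data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
bitmask = csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))
print(bitmask.toarray())
# [[ True  True False False]
#  [False  True False False]
#  [False False  True  True]]

The size-1-chunks shortcut above uses the same construction: when shape == (nchunks,), chunk i contains exactly labels[i], so rows and cols can be written down directly (masking out the -1 "no group" sentinel) instead of looping over chunk slices.
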
@@ -272,10 +282,8 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
         label_is_present[:] = False
 
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)
-    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
-    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
 
-    return bitmask
+    return make_bitmask(rows_array, cols_array)
 
 
 # @memoize
@@ -312,13 +320,18 @@ def find_group_cohorts(
     labels = np.asarray(labels)
 
     shape = tuple(sum(c) for c in chunks)
+    nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
     if expected_groups is None:
         nlabels = labels.max() + 1
     else:
         nlabels = expected_groups[-1] + 1
 
+    # 1. Single chunk, blockwise always
+    if nchunks == 1:
+        return "blockwise", {(0,): list(range(nlabels))}
+
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
     bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
 
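
Note: a quick illustration, with assumed values, of the new single-chunk early return. With one chunk, every label necessarily lives in chunk (0,), so "blockwise" is trivially optimal and the bitmask never needs to be built.

import math

chunks = ((10,),)  # one chunk of size 10 along the grouped axis
nlabels = 3
nchunks = math.prod(len(c) for c in chunks)  # 1
if nchunks == 1:
    method, cohorts = "blockwise", {(0,): list(range(nlabels))}
print(method, cohorts)  # blockwise {(0,): [0, 1, 2]}
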
@@ -346,21 +359,21 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
-    # 1. Every group is contained to one block, use blockwise here.
+    # 2. Every group is contained to one block, use blockwise here.
     if bitmask.shape[CHUNK_AXIS] == 1 or (chunks_per_label == 1).all():
         logger.info("find_group_cohorts: blockwise is preferred.")
         return "blockwise", chunks_cohorts
 
-    # 2. Perfectly chunked so there is only a single cohort
+    # 3. Perfectly chunked so there is only a single cohort
     if len(chunks_cohorts) == 1:
         logger.info("Only found a single cohort. 'map-reduce' is preferred.")
         return "map-reduce", chunks_cohorts if merge else {}
 
-    # 3. Our dataset has chunksize one along the axis,
+    # 4. Our dataset has chunksize one along the axis,
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-    # 4. Every chunk only has a single group, but that group might extend across multiple chunks
+    # 5. Every chunk only has a single group, but that group might extend across multiple chunks
     one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
-    # 5. Existing cohorts don't overlap, great for time grouping with perfect chunking
+    # 6. Existing cohorts don't overlap, great for time grouping with perfect chunking
     no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
     if one_group_per_chunk or single_chunks or no_overlapping_cohorts:
         logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
@@ -393,6 +406,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
             sparsity, MAX_SPARSITY_FOR_COHORTS
         )
     )
+    # 7. Groups seem fairly randomly distributed, use "map-reduce".
     if sparsity > MAX_SPARSITY_FOR_COHORTS:
         if not merge:
             logger.info(
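
Note: a sketch of the new heuristic 7. When the bitmask is dense, labels are spread across many chunks, cohorts buy little, and "map-reduce" is the fallback. Both the sparsity definition (filled fraction, nnz over matrix size) and the threshold value below are assumptions for illustration; flox defines its own MAX_SPARSITY_FOR_COHORTS.

import math

import numpy as np
from scipy.sparse import csc_array

MAX_SPARSITY_FOR_COHORTS = 0.4  # assumed threshold, for illustration only
# A fully dense bitmask: each of 3 chunks contains every one of 4 labels.
rows = np.repeat(np.arange(3), 4)
cols = np.tile(np.arange(4), 3)
data = np.ones(rows.shape, dtype=bool)
bitmask = csc_array((data, (rows, cols)), shape=(3, 4))

sparsity = bitmask.nnz / math.prod(bitmask.shape)  # 1.0 here
if sparsity > MAX_SPARSITY_FOR_COHORTS:
    print("map-reduce")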