Commit 627bf2b

Optimize bitmask finding for chunk size 1 and single chunk cases (#360)
* Optimize bitmask finding for chunk size 1.
* Fix benchmark.
* Bug fix.
* Add single chunk benchmark.
* Optimize single chunk case.
* Add test.
1 parent 13cb229 commit 627bf2b

File tree

asv_bench/benchmarks/cohorts.py
flox/core.py
tests/test_core.py

3 files changed: +64 −25 lines changed

asv_bench/benchmarks/cohorts.py

Lines changed: 24 additions & 14 deletions
@@ -1,3 +1,5 @@
+from functools import cached_property
+
 import dask
 import numpy as np
 import pandas as pd
@@ -11,6 +13,10 @@ class Cohorts:
     def setup(self, *args, **kwargs):
         raise NotImplementedError
 
+    @cached_property
+    def dask(self):
+        return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0].dask
+
     def containment(self):
         asfloat = self.bitmask().astype(float)
         chunks_per_label = asfloat.sum(axis=0)
@@ -43,26 +49,17 @@ def time_find_group_cohorts(self):
         pass
 
     def time_graph_construct(self):
-        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis, method="cohorts")
+        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)
 
     def track_num_tasks(self):
-        result = flox.groupby_reduce(
-            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
-        )[0]
-        return len(result.dask.to_dict())
+        return len(self.dask.to_dict())
 
     def track_num_tasks_optimized(self):
-        result = flox.groupby_reduce(
-            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
-        )[0]
-        (opt,) = dask.optimize(result)
-        return len(opt.dask.to_dict())
+        (opt,) = dask.optimize(self.dask)
+        return len(opt.to_dict())
 
     def track_num_layers(self):
-        result = flox.groupby_reduce(
-            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
-        )[0]
-        return len(result.dask.layers)
+        return len(self.dask.layers)
 
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
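The track_* rewrites above all lean on the new dask cached property: the expensive graph construction runs once per benchmark instance instead of once per tracked metric, and dropping method="cohorts" lets flox pick the method itself, which is exactly the heuristic being benchmarked. A minimal standalone sketch of the caching pattern (the graph dict here is a hypothetical stand-in for the real dask graph):

from functools import cached_property


class Bench:
    @cached_property
    def graph(self):
        # Runs only on first access; the result is stored on the instance.
        print("building graph")
        return {"layer-0": 0, "layer-1": 1}

    def track_num_layers(self):
        return len(self.graph)


b = Bench()
b.track_num_layers()  # prints "building graph"
b.track_num_layers()  # cached: no rebuild, no print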
@@ -193,6 +190,19 @@ def setup(self, *args, **kwargs):
         self.expected = pd.RangeIndex(self.by.max() + 1)
 
 
+class SingleChunk(Cohorts):
+    """Single chunk along reduction axis: always blockwise."""
+
+    def setup(self, *args, **kwargs):
+        index = pd.date_range("1959-01-01", freq="D", end="1962-12-31")
+        self.time = pd.Series(index)
+        TIME = len(self.time)
+        self.axis = (2,)
+        self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, -1))
+        self.by = codes_for_resampling(index, freq="5D")
+        self.expected = pd.RangeIndex(self.by.max() + 1)
+
+
 class OISST(Cohorts):
     def setup(self, *args, **kwargs):
         self.array = dask.array.ones((1, 14532), chunks=(1, 10))
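codes_for_resampling is a helper defined elsewhere in this benchmark module; it turns the daily index into factorized integer codes for 5-day bins. A rough stand-in showing the kind of output setup expects (an assumption for illustration, not the actual helper):

import numpy as np
import pandas as pd


def codes_for_resampling_sketch(index: pd.DatetimeIndex, freq_days: int = 5) -> np.ndarray:
    # Integer bin code per timestamp: day offset from the start,
    # floor-divided by the bin width.
    return ((index - index[0]).days // freq_days).to_numpy()


index = pd.date_range("1959-01-01", freq="D", end="1962-12-31")
by = codes_for_resampling_sketch(index)
print(by[:7])  # [0 0 0 0 0 1 1]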

flox/core.py

Lines changed: 23 additions & 9 deletions
@@ -248,12 +248,22 @@ def slices_from_chunks(chunks):
 
 
 def _compute_label_chunk_bitmask(labels, chunks, nlabels):
+    def make_bitmask(rows, cols):
+        data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
+        return csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))
+
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
 
-    labels = np.broadcast_to(labels, shape[-labels.ndim :])
+    # Shortcut for 1D with size-1 chunks
+    if shape == (nchunks,):
+        rows_array = np.arange(nchunks)
+        cols_array = labels
+        mask = labels >= 0
+        return make_bitmask(rows_array[mask], cols_array[mask])
 
+    labels = np.broadcast_to(labels, shape[-labels.ndim :])
     cols = []
     # Add one to handle the -1 sentinel value
     label_is_present = np.zeros((nlabels + 1,), dtype=bool)
@@ -272,10 +282,8 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
         label_is_present[:] = False
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)
-    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
-    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
 
-    return bitmask
+    return make_bitmask(rows_array, cols_array)
 
 
 # @memoize
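The point of the shortcut: with 1D size-1 chunks, chunk i holds exactly element i, so the (chunk, label) pairs are simply (i, labels[i]) and the per-chunk loop is skipped entirely; -1 is the "not in any group" sentinel and is masked out. A self-contained rerun of that logic, using the same scipy.sparse.csc_array constructor the function calls:

import numpy as np
from scipy.sparse import csc_array

labels = np.array([0, 1, 1, -1, 2])  # one element per chunk
nchunks, nlabels = labels.size, labels.max() + 1

rows, cols = np.arange(nchunks), labels
mask = labels >= 0  # drop the -1 "no group" sentinel
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows[mask].shape)
bitmask = csc_array(
    (data, (rows[mask], cols[mask])), dtype=bool, shape=(nchunks, nlabels)
)
print(bitmask.toarray().astype(int))
# [[1 0 0]
#  [0 1 0]
#  [0 1 0]
#  [0 0 0]
#  [0 0 1]]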
@@ -312,13 +320,18 @@ def find_group_cohorts(
     labels = np.asarray(labels)
 
     shape = tuple(sum(c) for c in chunks)
+    nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
     if expected_groups is None:
         nlabels = labels.max() + 1
     else:
         nlabels = expected_groups[-1] + 1
 
+    # 1. Single chunk, blockwise always
+    if nchunks == 1:
+        return "blockwise", {(0,): list(range(nlabels))}
+
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
     bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
 
@@ -346,21 +359,21 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
-    # 1. Every group is contained to one block, use blockwise here.
+    # 2. Every group is contained to one block, use blockwise here.
     if bitmask.shape[CHUNK_AXIS] == 1 or (chunks_per_label == 1).all():
         logger.info("find_group_cohorts: blockwise is preferred.")
         return "blockwise", chunks_cohorts
 
-    # 2. Perfectly chunked so there is only a single cohort
+    # 3. Perfectly chunked so there is only a single cohort
     if len(chunks_cohorts) == 1:
         logger.info("Only found a single cohort. 'map-reduce' is preferred.")
         return "map-reduce", chunks_cohorts if merge else {}
 
-    # 3. Our dataset has chunksize one along the axis,
+    # 4. Our dataset has chunksize one along the axis,
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-    # 4. Every chunk only has a single group, but that group might extend across multiple chunks
+    # 5. Every chunk only has a single group, but that group might extend across multiple chunks
     one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
-    # 5. Existing cohorts don't overlap, great for time grouping with perfect chunking
+    # 6. Existing cohorts don't overlap, great for time grouping with perfect chunking
     no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
     if one_group_per_chunk or single_chunks or no_overlapping_cohorts:
         logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
@@ -393,6 +406,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
             sparsity, MAX_SPARSITY_FOR_COHORTS
         )
     )
+    # 7. Groups seem fairly randomly distributed, use "map-reduce".
    if sparsity > MAX_SPARSITY_FOR_COHORTS:
        if not merge:
            logger.info(
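With this change the decision tree short-circuits before any bitmask work whenever the reduction axis has a single chunk. A usage sketch of the new early return (call signature taken from the test added below; the import path is assumed from the diff):

import numpy as np
import dask.array
from flox.core import find_group_cohorts

array = dask.array.ones((10,), chunks=-1)  # one chunk along the axis
labels = np.repeat(np.arange(5), 2)  # factorized labels 0..4
method, cohorts = find_group_cohorts(labels, chunks=(array.chunks[-1],))
assert method == "blockwise"  # case 1: nchunks == 1
assert cohorts == {(0,): [0, 1, 2, 3, 4]}  # every label lands in chunk (0,)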

tests/test_core.py

Lines changed: 17 additions & 2 deletions
@@ -946,12 +946,12 @@ def test_verify_complex_cohorts(chunksize: int) -> None:
 @pytest.mark.parametrize("chunksize", (12,) + tuple(range(1, 13)) + (-1,))
 def test_method_guessing(chunksize):
     # just a regression test
-    labels = np.tile(np.arange(1, 13), 30)
+    labels = np.tile(np.arange(0, 12), 30)
     by = dask.array.from_array(labels, chunks=chunksize) - 1
     preferred_method, chunks_cohorts = find_group_cohorts(labels, by.chunks[slice(-1, None)])
     if chunksize == -1:
         assert preferred_method == "blockwise"
-        assert chunks_cohorts == {(0,): list(range(1, 13))}
+        assert chunks_cohorts == {(0,): list(range(12))}
     elif chunksize in (1, 2, 3, 4, 6):
         assert preferred_method == "cohorts"
         assert len(chunks_cohorts) == 12 // chunksize
@@ -960,6 +960,21 @@ def test_method_guessing(chunksize):
         assert chunks_cohorts == {}
 
 
+@requires_dask
+@pytest.mark.parametrize("ndim", [1, 2, 3])
+def test_single_chunk_method_is_blockwise(ndim):
+    for by_ndim in range(1, ndim + 1):
+        chunks = (5,) * (ndim - by_ndim) + (-1,) * by_ndim
+        assert len(chunks) == ndim
+        array = dask.array.ones(shape=(10,) * ndim, chunks=chunks)
+        by = np.zeros(shape=(10,) * by_ndim, dtype=int)
+        method, chunks_cohorts = find_group_cohorts(
+            by, chunks=[array.chunks[ax] for ax in range(-by.ndim, 0)]
+        )
+        assert method == "blockwise"
+        assert chunks_cohorts == {(0,): [0]}
+
+
 @requires_dask
 @pytest.mark.parametrize(
     "chunk_at,expected",
