Skip to content

Commit 839b4d0

Browse files
authored
Auto rechunk to enable blockwise reduction (#380)
1 parent 97d408e commit 839b4d0

File tree

5 files changed

+167
-28
lines changed

5 files changed

+167
-28
lines changed

flox/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
"""Top-level module for flox ."""
44

55
from . import cache
6-
from .aggregations import Aggregation, Scan # noqa
6+
from .aggregations import Aggregation, Scan
77
from .core import (
88
groupby_reduce,
99
groupby_scan,
1010
rechunk_for_blockwise,
1111
rechunk_for_cohorts,
1212
ReindexStrategy,
1313
ReindexArrayType,
14-
) # noqa
14+
)
15+
from .options import set_options
1516

1617

1718
def _get_version():
@@ -24,3 +25,15 @@ def _get_version():
2425

2526

2627
__version__ = _get_version()
28+
29+
__all__ = [
30+
"Aggregation",
31+
"Scan",
32+
"groupby_reduce",
33+
"groupby_scan",
34+
"rechunk_for_blockwise",
35+
"rechunk_for_cohorts",
36+
"set_options",
37+
"ReindexStrategy",
38+
"ReindexArrayType",
39+
]

flox/core.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from .cache import memoize
5050
from .lib import ArrayLayer, dask_array_type, sparse_array_type
51+
from .options import OPTIONS
5152
from .xrutils import (
5253
_contains_cftime_datetimes,
5354
_to_pytimedelta,
@@ -111,6 +112,7 @@
111112
# _simple_combine.
112113
DUMMY_AXIS = -2
113114

115+
114116
logger = logging.getLogger("flox")
115117

116118

@@ -215,8 +217,11 @@ def identity(x: T) -> T:
215217
return x
216218

217219

218-
def _issorted(arr: np.ndarray) -> bool:
219-
return bool((arr[:-1] <= arr[1:]).all())
220+
def _issorted(arr: np.ndarray, ascending: bool = True) -> bool:
    """
    Check whether ``arr`` is monotonically ordered along its first axis.

    The comparison is non-strict, so repeated neighbouring values count as
    sorted. Arrays with fewer than two elements are trivially sorted because
    there are no neighbouring pairs to compare.

    Parameters
    ----------
    arr : np.ndarray
        Array to check.
    ascending : bool, default: True
        If True, require ``arr[i] <= arr[i + 1]`` for every neighbouring pair;
        otherwise require ``arr[i] >= arr[i + 1]``.

    Returns
    -------
    bool
        True if every neighbouring pair satisfies the requested ordering.
    """
    if ascending:
        return bool((arr[:-1] <= arr[1:]).all())
    else:
        return bool((arr[:-1] >= arr[1:]).all())
220225

221226

222227
def _is_arg_reduction(func: T_Agg) -> bool:
@@ -299,7 +304,7 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
299304
def _get_optimal_chunks_for_groups(chunks, labels):
300305
chunkidx = np.cumsum(chunks) - 1
301306
# what are the groups at chunk boundaries
302-
labels_at_chunk_bounds = _unique(labels[chunkidx])
307+
labels_at_chunk_bounds = pd.unique(labels[chunkidx])
303308
# what's the last index of all groups
304309
last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
305310
# what's the last index of groups at the chunk boundaries.
@@ -317,6 +322,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
317322
Δl = abs(c - l)
318323
if c == 0 or newchunkidx[-1] > l:
319324
continue
325+
f = f.item() # noqa
326+
l = l.item() # noqa
320327
if Δf < Δl and f > newchunkidx[-1]:
321328
newchunkidx.append(f)
322329
else:
@@ -708,7 +715,9 @@ def rechunk_for_cohorts(
708715
return array.rechunk({axis: newchunks})
709716

710717

711-
def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
718+
def rechunk_for_blockwise(
719+
array: DaskArray, axis: T_Axis, labels: np.ndarray, *, force: bool = True
720+
) -> tuple[T_MethodOpt, DaskArray]:
712721
"""
713722
Rechunks array so that group boundaries line up with chunk boundaries, allowing
714723
embarrassingly parallel group reductions.
@@ -731,14 +740,47 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
731740
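tuple[T_MethodOpt, DaskArray]
732741
``"blockwise"`` together with the (possibly rechunked) array, or ``None`` together with the unchanged array when ``force=False`` and the rechunking thresholds are not met.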
733742
"""
734-
# TODO: this should be unnecessary?
735-
labels = factorize_((labels,), axes=())[0]
743+
736744
chunks = array.chunks[axis]
737-
newchunks = _get_optimal_chunks_for_groups(chunks, labels)
745+
if len(chunks) == 1:
746+
return "blockwise", array
747+
748+
# import dask
749+
# from dask.utils import parse_bytes
750+
# factor = parse_bytes(dask.config.get("array.chunk-size")) / (
751+
# math.prod(array.chunksize) * array.dtype.itemsize
752+
# )
753+
# if factor > BLOCKWISE_DEFAULT_ARRAY_CHUNK_SIZE_FACTOR:
754+
# new_constant_chunks = math.ceil(factor) * max(chunks)
755+
# q, r = divmod(array.shape[axis], new_constant_chunks)
756+
# new_input_chunks = (new_constant_chunks,) * q + (r,)
757+
# else:
758+
new_input_chunks = chunks
759+
760+
# FIXME: this should be unnecessary?
761+
labels = factorize_((labels,), axes=())[0]
762+
newchunks = _get_optimal_chunks_for_groups(new_input_chunks, labels)
738763
if newchunks == chunks:
739-
return array
764+
return "blockwise", array
765+
766+
Δn = abs(len(newchunks) - len(new_input_chunks))
767+
if pass_num_chunks_threshold := (
768+
Δn / len(new_input_chunks) < OPTIONS["rechunk_blockwise_num_chunks_threshold"]
769+
):
770+
logger.debug("blockwise rechunk passes num chunks threshold")
771+
if pass_chunk_size_threshold := (
772+
# we just pick the max because number of chunks may have changed.
773+
(abs(max(newchunks) - max(new_input_chunks)) / max(new_input_chunks))
774+
< OPTIONS["rechunk_blockwise_chunk_size_threshold"]
775+
):
776+
logger.debug("blockwise rechunk passes chunk size change threshold")
777+
778+
if force or (pass_num_chunks_threshold and pass_chunk_size_threshold):
779+
logger.debug("Rechunking to enable blockwise.")
780+
return "blockwise", array.rechunk({axis: newchunks})
740781
else:
741-
return array.rechunk({axis: newchunks})
782+
logger.debug("Didn't meet thresholds to do automatic rechunking for blockwise reductions.")
783+
return None, array
742784

743785

744786
def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
@@ -2704,6 +2746,11 @@ def groupby_reduce(
27042746
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
27052747
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
27062748

2749+
if method is None and is_duck_dask_array(array) and not any_by_dask and by_.ndim == 1 and _issorted(by_):
2750+
# Let's try rechunking for sorted 1D by.
2751+
(single_axis,) = axis_
2752+
method, array = rechunk_for_blockwise(array, single_axis, by_, force=False)
2753+
27072754
is_first_last = _is_first_last_reduction(func)
27082755
if is_first_last:
27092756
if has_dask and nax != 1:
@@ -2891,7 +2938,7 @@ def groupby_reduce(
28912938

28922939
# if preferred method is already blockwise, no need to rechunk
28932940
if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1:
2894-
array = rechunk_for_blockwise(array, axis=-1, labels=by_)
2941+
_, array = rechunk_for_blockwise(array, axis=-1, labels=by_)
28952942

28962943
result, groups = partial_agg(
28972944
array=array,

flox/options.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Started from xarray options.py; vendored from cf-xarray
3+
"""
4+
5+
import copy
6+
from collections.abc import MutableMapping
7+
from typing import Any
8+
9+
OPTIONS: MutableMapping[str, Any] = {
    # Thresholds below which we will automatically rechunk to blockwise if it makes sense
    # 1. Fractional change in number of chunks after rechunking
    "rechunk_blockwise_num_chunks_threshold": 0.25,
    # 2. Fractional change in max chunk size after rechunking
    "rechunk_blockwise_chunk_size_threshold": 1.5,
    # 3. If input arrays have chunk size smaller than `dask.array.chunk-size`,
    # then adjust chunks to meet that size first.
    # "rechunk.blockwise.chunk_size_factor": 1.5,
}


class set_options:  # numpydoc ignore=PR01,PR02
    """
    Set options for flox in a controlled context.

    The new values are written to the module-level ``OPTIONS`` mapping as soon
    as the object is constructed. When the object is used as a context manager,
    the previous values of the touched options are restored on exit.

    Parameters
    ----------
    rechunk_blockwise_num_chunks_threshold : float
        Rechunk if fractional change in number of chunks after rechunking
        is less than this amount.
    rechunk_blockwise_chunk_size_threshold : float
        Rechunk if fractional change in max chunk size after rechunking
        is less than this threshold.

    Raises
    ------
    ValueError
        If a keyword argument is not a key of ``OPTIONS``. No option is
        modified in that case.

    Examples
    --------

    You can use ``set_options`` either as a context manager:

    >>> import flox
    >>> with flox.set_options(rechunk_blockwise_num_chunks_threshold=1):
    ...     pass

    Or to set global options:

    >>> flox.set_options(rechunk_blockwise_num_chunks_threshold=1)  # doctest: +ELLIPSIS
    <flox.options.set_options object at 0x...>
    """

    def __init__(self, **kwargs):
        # Remember the current value of every option we are about to override
        # so that ``__exit__`` can put them back. Validation happens in the same
        # pass, before anything is written to OPTIONS.
        self.old = {}
        for k in kwargs:
            if k not in OPTIONS:
                raise ValueError(f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}")
            self.old[k] = OPTIONS[k]
        self._apply_update(kwargs)

    def _apply_update(self, options_dict):
        # Deep-copy so later mutation of the caller's values cannot leak into
        # the global OPTIONS mapping.
        options_dict = copy.deepcopy(options_dict)
        OPTIONS.update(options_dict)

    def __enter__(self):
        # Options were already applied in ``__init__``; nothing is bound by
        # ``with ... as``.
        return

    def __exit__(self, type, value, traceback):
        # Restore only the options this instance changed.
        self._apply_update(self.old)

flox/xarray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import pandas as pd
8+
import toolz
89
import xarray as xr
910
from packaging.version import Version
1011

@@ -589,7 +590,7 @@ def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_Data
589590
DataArray or Dataset
590591
Xarray object with rechunked arrays.
591592
"""
592-
return _rechunk(rechunk_array_for_blockwise, obj, dim, labels)
593+
return _rechunk(toolz.compose(toolz.last, rechunk_array_for_blockwise), obj, dim, labels)
593594

594595

595596
def _rechunk(func, obj, dim, labels, **kwargs):

tests/test_core.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from numpy_groupies.aggregate_numpy import aggregate
1515

1616
import flox
17+
from flox import set_options, xrutils
1718
from flox import xrdtypes as dtypes
18-
from flox import xrutils
1919
from flox.aggregations import Aggregation, _initialize_aggregation
2020
from flox.core import (
2121
HAS_NUMBAGG,
@@ -31,6 +31,7 @@
3131
find_group_cohorts,
3232
groupby_reduce,
3333
groupby_scan,
34+
rechunk_for_blockwise,
3435
rechunk_for_cohorts,
3536
reindex_,
3637
subset_to_blocks,
@@ -979,26 +980,39 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
979980
assert_equal(actual, expected)
980981

981982

983+
@requires_dask
982984
@pytest.mark.parametrize(
983-
"inchunks, expected",
985+
"inchunks, expected, expected_method",
984986
[
985-
[(1,) * 10, (3, 2, 2, 3)],
986-
[(2,) * 5, (3, 2, 2, 3)],
987-
[(3, 3, 3, 1), (3, 2, 5)],
988-
[(3, 1, 1, 2, 1, 1, 1), (3, 2, 2, 3)],
989-
[(3, 2, 2, 3), (3, 2, 2, 3)],
990-
[(4, 4, 2), (3, 4, 3)],
991-
[(5, 5), (5, 5)],
992-
[(6, 4), (5, 5)],
993-
[(7, 3), (7, 3)],
994-
[(8, 2), (7, 3)],
995-
[(9, 1), (10,)],
996-
[(10,), (10,)],
987+
[(1,) * 10, (3, 2, 2, 3), None],
988+
[(2,) * 5, (3, 2, 2, 3), None],
989+
[(3, 3, 3, 1), (3, 2, 5), None],
990+
[(3, 1, 1, 2, 1, 1, 1), (3, 2, 2, 3), None],
991+
[(3, 2, 2, 3), (3, 2, 2, 3), "blockwise"],
992+
[(4, 4, 2), (3, 4, 3), None],
993+
[(5, 5), (5, 5), "blockwise"],
994+
[(6, 4), (5, 5), None],
995+
[(7, 3), (7, 3), "blockwise"],
996+
[(8, 2), (7, 3), None],
997+
[(9, 1), (10,), None],
998+
[(10,), (10,), "blockwise"],
997999
],
9981000
)
999-
def test_rechunk_for_blockwise(inchunks, expected):
1001+
def test_rechunk_for_blockwise(inchunks, expected, expected_method):
10001002
labels = np.array([1, 1, 1, 2, 2, 3, 3, 5, 5, 5])
10011003
assert _get_optimal_chunks_for_groups(inchunks, labels) == expected
1004+
# reversed
1005+
assert _get_optimal_chunks_for_groups(inchunks, labels[::-1]) == expected
1006+
1007+
with set_options(rechunk_blockwise_chunk_size_threshold=-1):
1008+
array = dask.array.ones(labels.size, chunks=(inchunks,))
1009+
method, array = rechunk_for_blockwise(array, -1, labels, force=False)
1010+
assert method == expected_method
1011+
assert array.chunks == (inchunks,)
1012+
1013+
method, array = rechunk_for_blockwise(array, -1, labels[::-1], force=False)
1014+
assert method == expected_method
1015+
assert array.chunks == (inchunks,)
10021016

10031017

10041018
@requires_dask

0 commit comments

Comments
 (0)