From 3c681493d823161a384e47607d4395251cda8689 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 16 Jul 2025 15:32:51 -0600 Subject: [PATCH 1/3] refactor: extract scan and factorize functions to separate modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create flox/scan.py with groupby_scan and related functions - Create flox/factorize.py with factorize_ and related functions - Move scan-related types to types.py under TYPE_CHECKING - Update imports across codebase (core, tests, benchmarks) - Maintain backward compatibility through __init__.py exports This improves code organization by separating scan operations and factorization logic from the large core.py module while preserving all existing functionality. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- asv_bench/benchmarks/cohorts.py | 5 +- flox/__init__.py | 2 +- flox/core.py | 735 +------------------------------- flox/factorize.py | 300 +++++++++++++ flox/reindex.py | 149 +++++++ flox/scan.py | 325 ++++++++++++++ flox/types.py | 51 ++- flox/utils.py | 40 ++ tests/test_core.py | 10 +- tests/test_properties.py | 3 +- 10 files changed, 892 insertions(+), 728 deletions(-) create mode 100644 flox/factorize.py create mode 100644 flox/reindex.py create mode 100644 flox/scan.py create mode 100644 flox/utils.py diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py index 8fa841fdd..8acab92f9 100644 --- a/asv_bench/benchmarks/cohorts.py +++ b/asv_bench/benchmarks/cohorts.py @@ -5,6 +5,7 @@ import pandas as pd import flox +from flox.factorize import _factorize_multiple from .helpers import codes_for_resampling @@ -89,7 +90,7 @@ def setup(self, *args, **kwargs): y = np.repeat(np.arange(30), 60) by = x[np.newaxis, :] * y[:, np.newaxis] - self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[0][0] + self.by = _factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[0][0] self.array = dask.array.ones(self.by.shape, chunks=(350, 350)) self.axis = (-2, -1) @@ -149,7 +150,7 @@ class ERA5MonthHour(ERA5Dataset, Cohorts): def setup(self, *args, **kwargs): super().__init__() by = (self.time.dt.month.values, self.time.dt.hour.values) - ret = flox.core._factorize_multiple( + ret = _factorize_multiple( by, (pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))), any_by_dask=False, diff --git a/flox/__init__.py b/flox/__init__.py index 898c10e24..a4e147332 100644 --- a/flox/__init__.py +++ b/flox/__init__.py @@ -6,12 +6,12 @@ from .aggregations import Aggregation, Scan # noqa from .core import ( groupby_reduce, - groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts, ReindexStrategy, ReindexArrayType, ) # noqa +from .scan import groupby_scan # noqa def _get_version(): diff --git a/flox/core.py b/flox/core.py index eb64d1370..c69f275a8 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1,18 +1,15 @@ from __future__ import annotations -import copy import datetime import itertools import logging import math import operator import warnings -from collections import namedtuple from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from enum import Enum, auto -from functools import partial, reduce +from functools import partial from itertools import product from numbers import Integral from typing import ( @@ -20,10 +17,8 @@ Any, Literal, TypeAlias, - TypedDict, TypeVar, Union, - cast, overload, ) @@ -36,18 +31,17 @@ from . 
import xrdtypes from .aggregate_flox import _prepare_for_flox from .aggregations import ( - AGGREGATIONS, Aggregation, - AlignedArrays, - Scan, - ScanState, _atleast_1d, _initialize_aggregation, generic_aggregate, quantile_new_dims_func, ) from .cache import memoize +from .factorize import _factorize_multiple, factorize_ from .lib import ArrayLayer, dask_array_type, sparse_array_type +from .reindex import reindex_ +from .utils import ReindexArrayType from .xrutils import ( _contains_cftime_datetimes, _to_pytimedelta, @@ -88,7 +82,6 @@ T_Func = str | Callable T_Funcs = T_Func | Sequence[T_Func] T_Agg = str | Aggregation - T_Scan = str | Scan T_Axis = int T_Axes = tuple[T_Axis, ...] T_AxesOpt = T_Axis | T_Axes | None @@ -104,7 +97,6 @@ IntermediateDict = dict[str | Callable, Any] FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]] -FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask") # This dummy axis is inserted using np.expand_dims # and then reduced over during the combine stage by @@ -114,36 +106,6 @@ logger = logging.getLogger("flox") -class ReindexArrayType(Enum): - """ - Enum describing which array type to reindex to. - - These are enumerated, rather than accepting a constructor, - because we might want to optimize for specific array types, - and because they don't necessarily have the same signature. - - For example, scipy.sparse.COO only supports a fill_value of 0. - """ - - AUTO = auto() - NUMPY = auto() - SPARSE_COO = auto() - # Sadly, scipy.sparse.coo_array only supports fill_value = 0 - # SCIPY_SPARSE_COO = auto() - # SPARSE_GCXS = auto() - - def is_same_type(self, other) -> bool: - match self: - case ReindexArrayType.AUTO: - return True - case ReindexArrayType.NUMPY: - return isinstance(other, np.ndarray) - case ReindexArrayType.SPARSE_COO: - import sparse - - return isinstance(other, sparse.COO) - - @dataclass class ReindexStrategy: """ @@ -183,16 +145,6 @@ def get_dask_meta(self, other, *, fill_value, dtype) -> Any: return sparse.COO.from_numpy(np.ones(shape=(0,) * other.ndim, dtype=dtype), fill_value=fill_value) -class FactorizeKwargs(TypedDict, total=False): - """Used in _factorize_multiple""" - - by: T_Bys - axes: T_Axes - fastpath: bool - reindex: bool - sort: bool - - def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups): """Account for numbagg not providing a fill_value kwarg.""" from .aggregate_numbagg import DEFAULT_FILL_VALUE @@ -741,316 +693,6 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> return array.rechunk({axis: newchunks}) -def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int): - idx = from_.get_indexer(to) - indexer = [slice(None, None)] * array.ndim - indexer[axis] = idx - reindexed = array[tuple(indexer)] - if (idx == -1).any(): - if fill_value is None: - raise ValueError("Filling is required. fill_value cannot be None.") - indexer[axis] = idx == -1 - reindexed = reindexed.astype(dtype, copy=False) - reindexed[tuple(indexer)] = fill_value - return reindexed - - -def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int): - import sparse - - assert axis == -1 - - # Are there any elements in `to` that are not in `from_`. - if isinstance(to, pd.RangeIndex) and len(to) > len(from_): - # 1. pandas optimizes set difference between two RangeIndexes only - # 2. We want to avoid realizing a very large numpy array in to memory. - # This happens in the `else` clause. 
- # There are potentially other tricks we can play, but this is a simple - # and effective one. If a user is reindexing to sparse, then len(to) is - # almost guaranteed to be > len(from_). If len(to) <= len(from_), then realizing - # another array of the same shape should be fine. - needs_reindex = True - else: - needs_reindex = (from_.get_indexer(to) == -1).any() - - if needs_reindex and fill_value is None: - raise ValueError("Filling is required. fill_value cannot be None.") - - idx = to.get_indexer(from_) - mask = idx != -1 # indices along last axis to keep - if mask.all(): - mask = slice(None) - shape = array.shape - - if isinstance(array, sparse.COO): - subset = array[..., mask] - data = subset.data - coords = subset.coords - if subset.nnz > 0: - coords[-1, :] = idx[mask][coords[-1, :]] - if fill_value is None: - # no reindexing is actually needed (dense case) - # preserve the fill_value - fill_value = array.fill_value - else: - ranges = np.broadcast_arrays( - *np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],))) - ) - coords = np.stack(ranges, axis=0).reshape(array.ndim, -1) - data = array[..., mask].reshape(-1) - - reindexed = sparse.COO( - coords=coords, - data=data.astype(dtype, copy=False), - shape=(*array.shape[:axis], to.size), - fill_value=fill_value, - ) - - return reindexed - - -def reindex_( - array: np.ndarray, - from_, - to, - *, - array_type: ReindexArrayType = ReindexArrayType.AUTO, - fill_value: Any = None, - axis: T_Axis = -1, - promote: bool = False, -) -> np.ndarray: - if not isinstance(to, pd.Index): - if promote: - to = pd.Index(to) - else: - raise ValueError("reindex requires a pandas.Index or promote=True") - - if to.ndim > 1: - raise ValueError(f"Cannot reindex to a multidimensional array: {to}") - - if array.shape[axis] == 0: - # all groups were NaN - shape = array.shape[:-1] + (len(to),) - if array_type in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY): - reindexed = np.full(shape, fill_value, dtype=array.dtype) - else: - raise NotImplementedError - return reindexed - - from_ = pd.Index(from_) - # short-circuit for trivial case - if from_.equals(to) and array_type.is_same_type(array): - return array - - if from_.dtype.kind == "O" and isinstance(from_[0], tuple): - raise NotImplementedError( - "Currently does not support reindexing with object arrays of tuples. " - "These occur when grouping by multi-indexed variables in xarray." - ) - if fill_value is xrdtypes.NA or isnull(fill_value): - new_dtype, fill_value = xrdtypes.maybe_promote(array.dtype) - else: - new_dtype = array.dtype - - if array_type is ReindexArrayType.AUTO: - if isinstance(array, sparse_array_type): - array_type = ReindexArrayType.SPARSE_COO - else: - # TODO: generalize here - # Right now, we effectively assume NEP-18 I think - array_type = ReindexArrayType.NUMPY - - if array_type is ReindexArrayType.NUMPY: - reindexed = reindex_numpy(array, from_, to, fill_value, new_dtype, axis) - elif array_type is ReindexArrayType.SPARSE_COO: - reindexed = reindex_pydata_sparse_coo(array, from_, to, fill_value, new_dtype, axis) - return reindexed - - -def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]: - """ - Offset group labels by dimension. This is used when we - reduce over a subset of the dimensions of by. 
It assumes that the reductions - dimensions have been flattened in the last dimension - Copied from xhistogram & - https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy - """ - assert labels.ndim > 1 - offset: np.ndarray = ( - labels + np.arange(math.prod(labels.shape[:-1])).reshape((*labels.shape[:-1], -1)) * ngroups - ) - # -1 indicates NaNs. preserve these otherwise we aggregate in the wrong groups! - offset[labels == -1] = -1 - size: int = math.prod(labels.shape[:-1]) * ngroups - return offset, size - - -def _factorize_single(by, expect, *, sort: bool, reindex: bool) -> tuple[pd.Index, np.ndarray]: - flat = by.reshape(-1) - if isinstance(expect, pd.RangeIndex): - # idx is a view of the original `by` array - # copy here so we don't have a race condition with the - # group_idx[nanmask] = nan_sentinel assignment later - # this is important in shared-memory parallelism with dask - # TODO: figure out how to avoid this - idx = flat.copy() - found_groups = cast(pd.Index, expect) - # TODO: fix by using masked integers - idx[idx > expect[-1]] = -1 - - elif isinstance(expect, pd.IntervalIndex): - if expect.closed == "both": - raise NotImplementedError - bins = np.concatenate([expect.left.to_numpy(), expect.right.to_numpy()[[-1]]]) - - # digitize is 0 or idx.max() for values outside the bounds of all intervals - # make it behave like pd.cut which uses -1: - if len(bins) > 1: - right = expect.closed_right - idx = np.digitize( - flat, - bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins, - right=right, - ) - idx -= 1 - within_bins = flat <= bins.max() if right else flat < bins.max() - idx[~within_bins] = -1 - else: - idx = np.zeros_like(flat, dtype=np.intp) - 1 - found_groups = cast(pd.Index, expect) - else: - if expect is not None and reindex: - sorter = np.argsort(expect) - groups = expect[(sorter,)] if sort else expect - idx = np.searchsorted(expect, flat, sorter=sorter) - mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect)) - if not sort: - # idx is the index in to the sorted array. - # if we didn't want sorting, unsort it back - idx[(idx == len(expect),)] = -1 - idx = sorter[(idx,)] - idx[mask] = -1 - else: - idx, groups = pd.factorize(flat, sort=sort) - found_groups = cast(pd.Index, groups) - - return (found_groups, idx.reshape(by.shape)) - - -def _ravel_factorized(*factorized: np.ndarray, grp_shape: tuple[int, ...]) -> np.ndarray: - group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap") - # NaNs; as well as values outside the bins are coded by -1 - # Restore these after the raveling - nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized]) - group_idx[nan_by_mask] = -1 - return group_idx - - -@overload -def factorize_( - by: T_Bys, - axes: T_Axes, - *, - fastpath: Literal[True], - expected_groups: T_ExpectIndexOptTuple | None = None, - reindex: bool = False, - sort: bool = True, -) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, None]: ... - - -@overload -def factorize_( - by: T_Bys, - axes: T_Axes, - *, - expected_groups: T_ExpectIndexOptTuple | None = None, - reindex: bool = False, - sort: bool = True, - fastpath: Literal[False] = False, -) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps]: ... 
- - -@overload -def factorize_( - by: T_Bys, - axes: T_Axes, - *, - expected_groups: T_ExpectIndexOptTuple | None = None, - reindex: bool = False, - sort: bool = True, - fastpath: bool = False, -) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]: ... - - -def factorize_( - by: T_Bys, - axes: T_Axes, - *, - expected_groups: T_ExpectIndexOptTuple | None = None, - reindex: bool = False, - sort: bool = True, - fastpath: bool = False, -) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]: - """ - Returns an array of integer codes for groups (and associated data) - by wrapping pd.cut and pd.factorize (depending on isbin). - This method handles reindex and sort so that we don't spend time reindexing / sorting - a possibly large results array. Instead we set up the appropriate integer codes (group_idx) - so that the results come out in the appropriate order. - """ - if expected_groups is None: - expected_groups = (None,) * len(by) - - if len(by) > 2: - with ThreadPoolExecutor() as executor: - futures = [ - executor.submit(partial(_factorize_single, sort=sort, reindex=reindex), groupvar, expect) - for groupvar, expect in zip(by, expected_groups) - ] - results = tuple(f.result() for f in futures) - else: - results = tuple( - _factorize_single(groupvar, expect, sort=sort, reindex=reindex) - for groupvar, expect in zip(by, expected_groups) - ) - found_groups = tuple(r[0] for r in results) - factorized = [r[1] for r in results] - - grp_shape = tuple(len(grp) for grp in found_groups) - ngroups = math.prod(grp_shape) - if len(by) > 1: - group_idx = _ravel_factorized(*factorized, grp_shape=grp_shape) - else: - (group_idx,) = factorized - - if fastpath: - return group_idx, found_groups, grp_shape, ngroups, ngroups, None - - if len(axes) == 1 and by[0].ndim > 1: - # Not reducing along all dimensions of by - # this is OK because for 3D by and axis=(1,2), - # we collapse to a 2D by and axis=-1 - offset_group = True - group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups) - else: - size = ngroups - offset_group = False - - # numpy_groupies cannot deal with group_idx = -1 - # so we'll add use ngroups as the sentinel - # note we cannot simply remove the NaN locations; - # that would mess up argmax, argmin - nan_sentinel = size if offset_group else ngroups - nanmask = group_idx == -1 - if nanmask.any(): - # bump it up so there's a place to assign values to the nan_sentinel index - size += 1 - group_idx[nanmask] = nan_sentinel - - props = FactorProps(offset_group, nan_sentinel, nanmask) - return group_idx, tuple(found_groups), grp_shape, ngroups, size, props - - def chunk_argreduce( array_plus_idx: tuple[np.ndarray, ...], by: np.ndarray, @@ -1470,24 +1112,6 @@ def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes | None = None) -> np.nd # return concatenate3(mapped) -def reindex_intermediates( - x: IntermediateDict, agg: Aggregation, unique_groups, array_type -) -> IntermediateDict: - new_shape = x["groups"].shape[:-1] + (len(unique_groups),) - newx: IntermediateDict = {"groups": np.broadcast_to(unique_groups, new_shape)} - newx["intermediates"] = tuple( - reindex_( - v, - from_=np.atleast_1d(x["groups"].squeeze()), - to=pd.Index(unique_groups), - fill_value=f, - array_type=array_type, - ) - for v, f in zip(x["intermediates"], agg.fill_value["intermediate"]) - ) - return newx - - def listify_groups(x: IntermediateDict): return list(np.atleast_1d(x["groups"].squeeze())) @@ -2315,68 +1939,6 @@ def 
_convert_expected_groups_to_index( return tuple(out) -def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray: - group_idx, *_ = factorize_(by, **kwargs) - return group_idx - - -def _factorize_multiple( - by: T_Bys, - expected_groups: T_ExpectIndexOptTuple, - any_by_dask: bool, - sort: bool = True, -) -> tuple[tuple[np.ndarray], tuple[pd.Index, ...], tuple[int, ...]]: - kwargs: FactorizeKwargs = dict( - axes=(), # always (), we offset later if necessary. - fastpath=True, - # This is the only way it makes sense I think. - # reindex controls what's actually allocated in chunk_reduce - # At this point, we care about an accurate conversion to codes. - reindex=True, - sort=sort, - ) - if any_by_dask: - import dask.array - - from . import dask_array_ops # noqa - - # unifying chunks will make sure all arrays in `by` are dask arrays - # with compatible chunks, even if there was originally a numpy array - inds = tuple(range(by[0].ndim)) - for by_, expect in zip(by, expected_groups): - if expect is None and is_duck_dask_array(by_): - raise ValueError("Please provide expected_groups when grouping by a dask array.") - - found_groups = tuple( - pd.Index(pd.unique(by_.reshape(-1))) if expect is None else expect - for by_, expect in zip(by, expected_groups) - ) - grp_shape = tuple(map(len, found_groups)) - - chunks, by_chunked = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by)))) - group_idxs = [ - dask.array.map_blocks( - _lazy_factorize_wrapper, - by_, - expected_groups=(expect_,), - meta=np.array((), dtype=np.int64), - **kwargs, - ) - for by_, expect_ in zip(by_chunked, expected_groups) - ] - # This could be avoied but we'd use `np.where` - # instead `_ravel_factorized` instead i.e. a copy. - group_idx = dask.array.map_blocks( - _ravel_factorized, *group_idxs, grp_shape=grp_shape, chunks=tuple(chunks.values()), dtype=np.int64 - ) - - else: - kwargs["by"] = by - group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs, expected_groups=expected_groups) - - return (group_idx,), found_groups, grp_shape - - @overload def _validate_expected_groups(nby: int, expected_groups: None) -> tuple[None, ...]: ... @@ -2953,281 +2515,18 @@ def groupby_reduce( return (result, *groups) -def groupby_scan( - array: np.ndarray | DaskArray, - *by: T_By, - func: T_Scan, - expected_groups: T_ExpectedGroupsOpt = None, - axis: int | tuple[int] = -1, - dtype: np.typing.DTypeLike = None, - method: T_MethodOpt = None, - engine: T_EngineOpt = None, -) -> np.ndarray | DaskArray: - """ - GroupBy reductions using parallel scans for dask.array - - Parameters - ---------- - array : ndarray or DaskArray - Array to be reduced, possibly nD - *by : ndarray or DaskArray - Array of labels to group over. Must be aligned with ``array`` so that - ``array.shape[-by.ndim :] == by.shape`` or any disagreements in that - equality check are for dimensions of size 1 in `by`. - func : {"nancumsum", "ffill", "bfill"} or Scan - Single function name or a Scan instance - expected_groups : (optional) Sequence - Expected unique labels. - axis : None or int or Sequence[int], optional - If None, reduce across all dimensions of by - Else, reduce across corresponding axes of array - Negative integers are normalized using array.ndim. - fill_value : Any - Value to assign when a label in ``expected_groups`` is not present. - dtype : data-type , optional - DType for the output. Can be anything that is accepted by ``np.dtype``. 
- method : {"blockwise", "cohorts"}, optional - Strategy for reduction of dask arrays only: - * ``"blockwise"``: - Only scan using blockwise and avoid aggregating blocks - together. Useful for resampling-style groupby problems where group - members are always together. If `by` is 1D, `array` is automatically - rechunked so that chunk boundaries line up with group boundaries - i.e. each block contains all members of any group present - in that block. For nD `by`, you must make sure that all members of a group - are present in a single block. - * ``"cohorts"``: - Finds group labels that tend to occur together ("cohorts"), - indexes out cohorts and reduces that subset using "map-reduce", - repeat for all cohorts. This works well for many time groupings - where the group labels repeat at regular intervals like 'hour', - 'month', dayofyear' etc. Optimize chunking ``array`` for this - method by first rechunking using ``rechunk_for_cohorts`` - (for 1D ``by`` only). - engine : {"flox", "numpy", "numba", "numbagg"}, optional - Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk: - * ``"numpy"``: - Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``. - This is the default choice because it works for most array types. - * ``"flox"``: - Use an internal implementation where the data is sorted so that - all members of a group occur sequentially, and then numpy.ufunc.reduceat - is to used for the reduction. This will fall back to ``numpy_groupies.aggregate_numpy`` - for a reduction that is not yet implemented. - * ``"numba"``: - Use the implementations in ``numpy_groupies.aggregate_numba``. - * ``"numbagg"``: - Use the reductions supported by ``numbagg.grouped``. This will fall back to ``numpy_groupies.aggregate_numpy`` - for a reduction that is not yet implemented. - - Returns - ------- - result - Aggregated result - - See Also - -------- - xarray.xarray_reduce - """ - - axis = _atleast_1d(axis) - if len(axis) > 1: - raise NotImplementedError("Scans are only supported along a single dimension.") - - bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) - nby = len(by) - by_is_dask = tuple(is_duck_dask_array(b) for b in bys) - any_by_dask = any(by_is_dask) - - axis_ = normalize_axis_tuple(axis, array.ndim) - - if engine is not None: - raise NotImplementedError("Setting `engine` is not supported for scans yet.") - if method is not None: - raise NotImplementedError("Setting `method` is not supported for scans yet.") - if engine is None: - engine = "flox" - assert engine == "flox" - - if not is_duck_array(array): - array = np.asarray(array) - - if isinstance(func, str): - agg = AGGREGATIONS[func] - assert isinstance(agg, Scan) - agg = copy.deepcopy(agg) - - if (agg == AGGREGATIONS["ffill"] or agg == AGGREGATIONS["bfill"]) and array.dtype.kind != "f": - # nothing to do, no NaNs! 
- return array - - if expected_groups is not None: - raise NotImplementedError("Setting `expected_groups` and binning is not supported yet.") - expected_groups = _validate_expected_groups(nby, expected_groups) - expected_groups = _convert_expected_groups_to_index(expected_groups, isbin=(False,) * nby, sort=False) - - # Don't factorize early only when - # grouping by dask arrays, and not having expected_groups - factorize_early = not ( - # can't do it if we are grouping by dask array but don't have expected_groups - any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups)) - ) - if factorize_early: - bys, final_groups, grp_shape = _factorize_multiple( - bys, - expected_groups, - any_by_dask=any_by_dask, - sort=False, +def reindex_intermediates(x, agg, unique_groups, array_type): + """Reindex intermediate results for groupby operations.""" + new_shape = x["groups"].shape[:-1] + (len(unique_groups),) + newx = {"groups": np.broadcast_to(unique_groups, new_shape)} + newx["intermediates"] = tuple( + reindex_( + v, + from_=np.atleast_1d(x["groups"].squeeze()), + to=pd.Index(unique_groups), + fill_value=f, + array_type=array_type, ) - else: - raise NotImplementedError - - assert len(bys) == 1 - by_: np.ndarray - (by_,) = bys - has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) - - if array.dtype.kind in "Mm": - cast_to = array.dtype - array = array.view(np.int64) - elif array.dtype.kind == "b": - array = array.view(np.int8) - cast_to = None - if agg.preserves_dtype: - cast_to = bool - else: - cast_to = None - - # TODO: move to aggregate_npg.py - if agg.name in ["cumsum", "nancumsum"] and array.dtype.kind in ["i", "u"]: - # https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html - # it defaults to the dtype of a, unless a - # has an integer dtype with a precision less than that of the default platform integer. - if array.dtype.kind == "i": - agg.dtype = np.result_type(array.dtype, np.int_) - elif array.dtype.kind == "u": - agg.dtype = np.result_type(array.dtype, np.uint) - else: - agg.dtype = array.dtype if dtype is None else dtype - agg.identity = xrdtypes._get_fill_value(agg.dtype, agg.identity) - - (single_axis,) = axis_ # type: ignore[misc] - # avoid some roundoff error when we can. 
- if by_.shape[-1] == 1 or by_.shape == grp_shape: - array = array.astype(agg.dtype) - if cast_to is not None: - array = array.astype(cast_to) - return array - - # Made a design choice here to have `preprocess` handle both array and group_idx - # Example: for reversing, we need to reverse the whole array, not just reverse - # each block independently - inp = AlignedArrays(array=array, group_idx=by_) - if agg.preprocess: - inp = agg.preprocess(inp) - - if not has_dask: - final_state = chunk_scan(inp, axis=single_axis, agg=agg, dtype=agg.dtype) - result = _finalize_scan(final_state, dtype=agg.dtype) - else: - result = dask_groupby_scan(inp.array, inp.group_idx, axes=axis_, agg=agg) - - # Made a design choice here to have `postprocess` handle both array and group_idx - out = AlignedArrays(array=result, group_idx=by_) - if agg.finalize: - out = agg.finalize(out) - - if cast_to is not None: - return out.array.astype(cast_to) - return out.array - - -def chunk_scan(inp: AlignedArrays, *, axis: int, agg: Scan, dtype=None, keepdims=None) -> ScanState: - assert axis == inp.array.ndim - 1 - - # I don't think we need to re-factorize here unless we are grouping by a dask array - accumulated = generic_aggregate( - inp.group_idx, - inp.array, - axis=axis, - engine="flox", - func=agg.scan, - dtype=dtype, - fill_value=agg.identity, - ) - result = AlignedArrays(array=accumulated, group_idx=inp.group_idx) - return ScanState(result=result, state=None) - - -def grouped_reduce(inp: AlignedArrays, *, agg: Scan, axis: int, keepdims=None) -> ScanState: - assert axis == inp.array.ndim - 1 - reduced = chunk_reduce( - inp.array, - inp.group_idx, - func=(agg.reduction,), - axis=axis, - engine="flox", - dtype=inp.array.dtype, - fill_value=agg.identity, - expected_groups=None, - ) - return ScanState( - state=AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"]), - result=None, - ) - - -def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays: - return AlignedArrays(group_idx=group_idx, array=array) - - -def _finalize_scan(block: ScanState, dtype) -> np.ndarray: - assert block.result is not None - return block.result.array.astype(dtype, copy=False) - - -def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray: - from dask.array import map_blocks - from dask.array.reductions import cumreduction as scan - - from flox.aggregations import scan_binary_op - - if len(axes) > 1: - raise NotImplementedError("Scans are only supported along a single axis.") - (axis,) = axes - - array, by = _unify_chunks(array, by) - - # 1. zip together group indices & array - zipped = map_blocks( - _zip, - by, - array, - dtype=array.dtype, - meta=array._meta, - name="groupby-scan-preprocess", - ) - - scan_ = partial(chunk_scan, agg=agg) - # dask tokenizing error workaround - scan_.__name__ = scan_.func.__name__ # type: ignore[attr-defined] - - # 2. Run the scan - accumulated = scan( - func=scan_, - binop=partial(scan_binary_op, agg=agg), - ident=agg.identity, - x=zipped, - axis=axis, - # TODO: support method="sequential" here. - method="blelloch", - preop=partial(grouped_reduce, agg=agg), - dtype=agg.dtype, + for v, f in zip(x["intermediates"], agg.fill_value["intermediate"]) ) - - # 3. 
Unzip and extract the final result array, discard groups - result = map_blocks(partial(_finalize_scan, dtype=agg.dtype), accumulated, dtype=agg.dtype) - - assert result.chunks == array.chunks - - return result + return newx diff --git a/flox/factorize.py b/flox/factorize.py new file mode 100644 index 000000000..ced27119a --- /dev/null +++ b/flox/factorize.py @@ -0,0 +1,300 @@ +"""Factorization functions for groupby operations.""" + +from __future__ import annotations + +import itertools +import math +from collections import namedtuple +from concurrent.futures import ThreadPoolExecutor +from functools import partial, reduce +from typing import TYPE_CHECKING, Literal, TypedDict, cast, overload + +import numpy as np +import pandas as pd + +if TYPE_CHECKING: + from .types import ( + T_Axes, + T_By, + T_Bys, + T_ExpectIndexOptTuple, + ) + +# Type definitions +FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask") + + +class FactorizeKwargs(TypedDict, total=False): + """Used in _factorize_multiple""" + + by: T_Bys + axes: T_Axes + expected_groups: T_ExpectIndexOptTuple | None + fastpath: bool + reindex: bool + sort: bool + + +def _factorize_single(by, expect, *, sort: bool, reindex: bool) -> tuple[pd.Index, np.ndarray]: + # Import here to avoid circular imports + from .core import isnull + + flat = by.reshape(-1) + if isinstance(expect, pd.RangeIndex): + # idx is a view of the original `by` array + # copy here so we don't have a race condition with the + # group_idx[nanmask] = nan_sentinel assignment later + # this is important in shared-memory parallelism with dask + # TODO: figure out how to avoid this + idx = flat.copy() + found_groups = cast(pd.Index, expect) + # TODO: fix by using masked integers + idx[idx > expect[-1]] = -1 + + elif isinstance(expect, pd.IntervalIndex): + if expect.closed == "both": + raise NotImplementedError + bins = np.concatenate([expect.left.to_numpy(), expect.right.to_numpy()[[-1]]]) + + # digitize is 0 or idx.max() for values outside the bounds of all intervals + # make it behave like pd.cut which uses -1: + if len(bins) > 1: + right = expect.closed_right + idx = np.digitize( + flat, + bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins, + right=right, + ) + idx -= 1 + within_bins = flat <= bins.max() if right else flat < bins.max() + idx[~within_bins] = -1 + else: + idx = np.zeros_like(flat, dtype=np.intp) - 1 + found_groups = cast(pd.Index, expect) + else: + if expect is not None and reindex: + sorter = np.argsort(expect) + groups = expect[(sorter,)] if sort else expect + idx = np.searchsorted(expect, flat, sorter=sorter) + mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect)) + if not sort: + # idx is the index in to the sorted array. + # if we didn't want sorting, unsort it back + idx[(idx == len(expect),)] = -1 + idx = sorter[(idx,)] + idx[mask] = -1 + else: + idx, groups = pd.factorize(flat, sort=sort) + found_groups = cast(pd.Index, groups) + + return (found_groups, idx.reshape(by.shape)) + + +def _ravel_factorized(*factorized: np.ndarray, grp_shape: tuple[int, ...]) -> np.ndarray: + group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap") + # NaNs; as well as values outside the bins are coded by -1 + # Restore these after the raveling + nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized]) + group_idx[nan_by_mask] = -1 + return group_idx + + +def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]: + """ + Offset group labels by dimension. 
This is used when we + reduce over a subset of the dimensions of by. It assumes that the reductions + dimensions have been flattened in the last dimension + Copied from xhistogram & + https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy + """ + assert labels.ndim > 1 + offset: np.ndarray = ( + labels + np.arange(math.prod(labels.shape[:-1])).reshape((*labels.shape[:-1], -1)) * ngroups + ) + # -1 indicates NaNs. preserve these otherwise we aggregate in the wrong groups! + offset[labels == -1] = -1 + size: int = math.prod(labels.shape[:-1]) * ngroups + return offset, size + + +@overload +def factorize_( + by: T_Bys, + axes: T_Axes, + *, + fastpath: Literal[True], + expected_groups: T_ExpectIndexOptTuple | None = None, + reindex: bool = False, + sort: bool = True, +) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, None]: ... + + +@overload +def factorize_( + by: T_Bys, + axes: T_Axes, + *, + expected_groups: T_ExpectIndexOptTuple | None = None, + reindex: bool = False, + sort: bool = True, + fastpath: Literal[False] = False, +) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps]: ... + + +@overload +def factorize_( + by: T_Bys, + axes: T_Axes, + *, + expected_groups: T_ExpectIndexOptTuple | None = None, + reindex: bool = False, + sort: bool = True, + fastpath: bool = False, +) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]: ... + + +def factorize_( + by: T_Bys, + axes: T_Axes, + *, + expected_groups: T_ExpectIndexOptTuple | None = None, + reindex: bool = False, + sort: bool = True, + fastpath: bool = False, +) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]: + """ + Returns an array of integer codes for groups (and associated data) + by wrapping pd.cut and pd.factorize (depending on isbin). + This method handles reindex and sort so that we don't spend time reindexing / sorting + a possibly large results array. Instead we set up the appropriate integer codes (group_idx) + so that the results come out in the appropriate order. 
+ """ + # offset_labels is now defined in this module + + if expected_groups is None: + expected_groups = (None,) * len(by) + + if len(by) > 2: + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(partial(_factorize_single, sort=sort, reindex=reindex), groupvar, expect) + for groupvar, expect in zip(by, expected_groups) + ] + results = tuple(f.result() for f in futures) + else: + results = tuple( + _factorize_single(groupvar, expect, sort=sort, reindex=reindex) + for groupvar, expect in zip(by, expected_groups) + ) + found_groups = tuple(r[0] for r in results) + factorized = [r[1] for r in results] + + grp_shape = tuple(len(grp) for grp in found_groups) + ngroups = math.prod(grp_shape) + if len(by) > 1: + group_idx = _ravel_factorized(*factorized, grp_shape=grp_shape) + else: + (group_idx,) = factorized + + if fastpath: + return group_idx, found_groups, grp_shape, ngroups, ngroups, None + + if len(axes) == 1 and by[0].ndim > 1: + # Not reducing along all dimensions of by + # this is OK because for 3D by and axis=(1,2), + # we collapse to a 2D by and axis=-1 + offset_group = True + group_idx, size = offset_labels(group_idx.reshape(by[0].shape), ngroups) + else: + size = ngroups + offset_group = False + + # numpy_groupies cannot deal with group_idx = -1 + # so we'll add use ngroups as the sentinel + # note we cannot simply remove the NaN locations; + # that would mess up argmax, argmin + nan_sentinel = size if offset_group else ngroups + nanmask = group_idx == -1 + if nanmask.any(): + # bump it up so there's a place to assign values to the nan_sentinel index + size += 1 + group_idx[nanmask] = nan_sentinel + + props = FactorProps(offset_group, nan_sentinel, nanmask) + return group_idx, tuple(found_groups), grp_shape, ngroups, size, props + + +def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray: + group_idx, *_ = factorize_(by, **kwargs) + return group_idx + + +def _factorize_multiple( + by: T_Bys, + expected_groups: T_ExpectIndexOptTuple, + any_by_dask: bool, + sort: bool = True, +) -> tuple[tuple[np.ndarray], tuple[pd.Index, ...], tuple[int, ...]]: + # Import here to avoid circular imports + from .core import is_duck_dask_array + + kwargs: FactorizeKwargs = dict( + axes=(), # always (), we offset later if necessary. + fastpath=True, + # This is the only way it makes sense I think. + # reindex controls what's actually allocated in chunk_reduce + # At this point, we care about an accurate conversion to codes. + reindex=True, + sort=sort, + ) + if any_by_dask: + import dask.array + + from . import dask_array_ops # noqa + + # unifying chunks will make sure all arrays in `by` are dask arrays + # with compatible chunks, even if there was originally a numpy array + inds = tuple(range(by[0].ndim)) + for by_, expect in zip(by, expected_groups): + if expect is None and is_duck_dask_array(by_): + raise ValueError("Please provide expected_groups when grouping by a dask array.") + + found_groups = tuple( + pd.Index(pd.unique(by_.reshape(-1))) if expect is None else expect + for by_, expect in zip(by, expected_groups) + ) + grp_shape = tuple(map(len, found_groups)) + + chunks, by_chunked = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by)))) + group_idxs = [ + dask.array.map_blocks( + _lazy_factorize_wrapper, + by_, + expected_groups=(expect_,), + meta=np.array((), dtype=np.int64), + **kwargs, + ) + for by_, expect_ in zip(by_chunked, expected_groups) + ] + # This could be avoied but we'd use `np.where` + # instead `_ravel_factorized` instead i.e. a copy. 
+ group_idx = dask.array.map_blocks( + _ravel_factorized, *group_idxs, grp_shape=grp_shape, chunks=tuple(chunks.values()), dtype=np.int64 + ) + + else: + kwargs["by"] = by + group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs, expected_groups=expected_groups) + + return (group_idx,), found_groups, grp_shape + + +__all__ = [ + "factorize_", + "_factorize_single", + "_factorize_multiple", + "_lazy_factorize_wrapper", + "_ravel_factorized", + "offset_labels", + "FactorProps", +] diff --git a/flox/reindex.py b/flox/reindex.py new file mode 100644 index 000000000..1732e7f11 --- /dev/null +++ b/flox/reindex.py @@ -0,0 +1,149 @@ +"""Reindexing functions for groupby operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd + +if TYPE_CHECKING: + from .core import T_Axis + +from . import xrdtypes +from .lib import sparse_array_type +from .utils import ReindexArrayType +from .xrutils import isnull + + +def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int): + idx = from_.get_indexer(to) + indexer = [slice(None, None)] * array.ndim + indexer[axis] = idx + reindexed = array[tuple(indexer)] + if (idx == -1).any(): + if fill_value is None: + raise ValueError("Filling is required. fill_value cannot be None.") + indexer[axis] = idx == -1 + reindexed = reindexed.astype(dtype, copy=False) + reindexed[tuple(indexer)] = fill_value + return reindexed + + +def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int): + import sparse + + assert axis == -1 + + # Are there any elements in `to` that are not in `from_`. + if isinstance(to, pd.RangeIndex) and len(to) > len(from_): + # 1. pandas optimizes set difference between two RangeIndexes only + # 2. We want to avoid realizing a very large numpy array in to memory. + # This happens in the `else` clause. + # There are potentially other tricks we can play, but this is a simple + # and effective one. If a user is reindexing to sparse, then len(to) is + # almost guaranteed to be > len(from_). If len(to) <= len(from_), then realizing + # another array of the same shape should be fine. + needs_reindex = True + else: + needs_reindex = (from_.get_indexer(to) == -1).any() + + if needs_reindex and fill_value is None: + raise ValueError("Filling is required. 
fill_value cannot be None.") + + idx = to.get_indexer(from_) + mask = idx != -1 # indices along last axis to keep + if mask.all(): + mask = slice(None) + shape = array.shape + + if isinstance(array, sparse.COO): + subset = array[..., mask] + data = subset.data + coords = subset.coords + if subset.nnz > 0: + coords[-1, :] = idx[mask][coords[-1, :]] + if fill_value is None: + # no reindexing is actually needed (dense case) + # preserve the fill_value + fill_value = array.fill_value + else: + ranges = np.broadcast_arrays( + *np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],))) + ) + coords = np.stack(ranges, axis=0).reshape(array.ndim, -1) + data = array[..., mask].reshape(-1) + + reindexed = sparse.COO( + coords=coords, + data=data.astype(dtype, copy=False), + shape=(*array.shape[:axis], to.size), + fill_value=fill_value, + ) + + return reindexed + + +def reindex_( + array: np.ndarray, + from_, + to, + *, + array_type: ReindexArrayType = ReindexArrayType.AUTO, + fill_value: Any = None, + axis: T_Axis = -1, + promote: bool = False, +) -> np.ndarray: + if not isinstance(to, pd.Index): + if promote: + to = pd.Index(to) + else: + raise ValueError("reindex requires a pandas.Index or promote=True") + + if to.ndim > 1: + raise ValueError(f"Cannot reindex to a multidimensional array: {to}") + + if array.shape[axis] == 0: + # all groups were NaN + shape = array.shape[:-1] + (len(to),) + if array_type in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY): + reindexed = np.full(shape, fill_value, dtype=array.dtype) + else: + raise NotImplementedError + return reindexed + + from_ = pd.Index(from_) + # short-circuit for trivial case + if from_.equals(to) and array_type.is_same_type(array): + return array + + if from_.dtype.kind == "O" and isinstance(from_[0], tuple): + raise NotImplementedError( + "Currently does not support reindexing with object arrays of tuples. " + "These occur when grouping by multi-indexed variables in xarray." 
+ ) + if fill_value is xrdtypes.NA or isnull(fill_value): + new_dtype, fill_value = xrdtypes.maybe_promote(array.dtype) + else: + new_dtype = array.dtype + + if array_type is ReindexArrayType.AUTO: + if isinstance(array, sparse_array_type): + array_type = ReindexArrayType.SPARSE_COO + else: + # TODO: generalize here + # Right now, we effectively assume NEP-18 I think + array_type = ReindexArrayType.NUMPY + + if array_type is ReindexArrayType.NUMPY: + reindexed = reindex_numpy(array, from_, to, fill_value, new_dtype, axis) + elif array_type is ReindexArrayType.SPARSE_COO: + reindexed = reindex_pydata_sparse_coo(array, from_, to, fill_value, new_dtype, axis) + return reindexed + + +__all__ = [ + "reindex_", + "reindex_numpy", + "reindex_pydata_sparse_coo", +] diff --git a/flox/scan.py b/flox/scan.py new file mode 100644 index 000000000..d4da062ac --- /dev/null +++ b/flox/scan.py @@ -0,0 +1,325 @@ +"""Scan operations for groupby reductions.""" + +from __future__ import annotations + +import copy +from functools import partial +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from .aggregations import AlignedArrays, Scan, ScanState + from .types import ( + DaskArray, + T_Axes, + T_By, + T_Bys, + T_EngineOpt, + T_ExpectedGroupsOpt, + T_MethodOpt, + T_Scan, + ) + + +def groupby_scan( + array: np.ndarray | DaskArray, + *by: T_By, + func: T_Scan, + expected_groups: T_ExpectedGroupsOpt = None, + axis: int | tuple[int] = -1, + dtype: np.typing.DTypeLike = None, + method: T_MethodOpt = None, + engine: T_EngineOpt = None, +) -> np.ndarray | DaskArray: + """ + GroupBy reductions using parallel scans for dask.array + + Parameters + ---------- + array : ndarray or DaskArray + Array to be reduced, possibly nD + *by : ndarray or DaskArray + Array of labels to group over. Must be aligned with ``array`` so that + ``array.shape[-by.ndim :] == by.shape`` or any disagreements in that + equality check are for dimensions of size 1 in `by`. + func : {"nancumsum", "ffill", "bfill"} or Scan + Single function name or a Scan instance + expected_groups : (optional) Sequence + Expected unique labels. + axis : None or int or Sequence[int], optional + If None, reduce across all dimensions of by + Else, reduce across corresponding axes of array + Negative integers are normalized using array.ndim. + fill_value : Any + Value to assign when a label in ``expected_groups`` is not present. + dtype : data-type , optional + DType for the output. Can be anything that is accepted by ``np.dtype``. + method : {"blockwise", "cohorts"}, optional + Strategy for reduction of dask arrays only: + * ``"blockwise"``: + Only scan using blockwise and avoid aggregating blocks + together. Useful for resampling-style groupby problems where group + members are always together. If `by` is 1D, `array` is automatically + rechunked so that chunk boundaries line up with group boundaries + i.e. each block contains all members of any group present + in that block. For nD `by`, you must make sure that all members of a group + are present in a single block. + * ``"cohorts"``: + Finds group labels that tend to occur together ("cohorts"), + indexes out cohorts and reduces that subset using "map-reduce", + repeat for all cohorts. This works well for many time groupings + where the group labels repeat at regular intervals like 'hour', + 'month', dayofyear' etc. Optimize chunking ``array`` for this + method by first rechunking using ``rechunk_for_cohorts`` + (for 1D ``by`` only). 
+ engine : {"flox", "numpy", "numba", "numbagg"}, optional + Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk: + * ``"numpy"``: + Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``. + This is the default choice because it works for most array types. + * ``"flox"``: + Use an internal implementation where the data is sorted so that + all members of a group occur sequentially, and then numpy.ufunc.reduceat + is to used for the reduction. This will fall back to ``numpy_groupies.aggregate_numpy`` + for a reduction that is not yet implemented. + * ``"numba"``: + Use the implementations in ``numpy_groupies.aggregate_numba``. + * ``"numbagg"``: + Use the reductions supported by ``numbagg.grouped``. This will fall back to ``numpy_groupies.aggregate_numpy`` + for a reduction that is not yet implemented. + + Returns + ------- + result + Aggregated result + + See Also + -------- + xarray.xarray_reduce + """ + # Import here to avoid circular imports + from . import xrdtypes + from .aggregations import AGGREGATIONS, AlignedArrays, Scan, _atleast_1d + from .core import ( + _convert_expected_groups_to_index, + _factorize_multiple, + _validate_expected_groups, + is_duck_array, + is_duck_dask_array, + normalize_axis_tuple, + ) + + axis = _atleast_1d(axis) + if len(axis) > 1: + raise NotImplementedError("Scans are only supported along a single dimension.") + + bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) + nby = len(by) + by_is_dask = tuple(is_duck_dask_array(b) for b in bys) + any_by_dask = any(by_is_dask) + + axis_ = normalize_axis_tuple(axis, array.ndim) + + if engine is not None: + raise NotImplementedError("Setting `engine` is not supported for scans yet.") + if method is not None: + raise NotImplementedError("Setting `method` is not supported for scans yet.") + if engine is None: + engine = "flox" + assert engine == "flox" + + if not is_duck_array(array): + array = np.asarray(array) + + if isinstance(func, str): + agg = AGGREGATIONS[func] + assert isinstance(agg, Scan) + agg = copy.deepcopy(agg) + + if (agg == AGGREGATIONS["ffill"] or agg == AGGREGATIONS["bfill"]) and array.dtype.kind != "f": + # nothing to do, no NaNs! 
+ return array + + if expected_groups is not None: + raise NotImplementedError("Setting `expected_groups` and binning is not supported yet.") + expected_groups = _validate_expected_groups(nby, expected_groups) + expected_groups = _convert_expected_groups_to_index(expected_groups, isbin=(False,) * nby, sort=False) + + # Don't factorize early only when + # grouping by dask arrays, and not having expected_groups + factorize_early = not ( + # can't do it if we are grouping by dask array but don't have expected_groups + any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups)) + ) + if factorize_early: + bys, final_groups, grp_shape = _factorize_multiple( + bys, + expected_groups, + any_by_dask=any_by_dask, + sort=False, + ) + else: + raise NotImplementedError + + assert len(bys) == 1 + by_: np.ndarray + (by_,) = bys + has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) + + if array.dtype.kind in "Mm": + cast_to = array.dtype + array = array.view(np.int64) + elif array.dtype.kind == "b": + array = array.view(np.int8) + cast_to = None + if agg.preserves_dtype: + cast_to = bool + else: + cast_to = None + + # TODO: move to aggregate_npg.py + if agg.name in ["cumsum", "nancumsum"] and array.dtype.kind in ["i", "u"]: + # https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html + # it defaults to the dtype of a, unless a + # has an integer dtype with a precision less than that of the default platform integer. + if array.dtype.kind == "i": + agg.dtype = np.result_type(array.dtype, np.int_) + elif array.dtype.kind == "u": + agg.dtype = np.result_type(array.dtype, np.uint) + else: + agg.dtype = array.dtype if dtype is None else dtype + agg.identity = xrdtypes._get_fill_value(agg.dtype, agg.identity) + + (single_axis,) = axis_ # type: ignore[misc] + # avoid some roundoff error when we can. 
+ if by_.shape[-1] == 1 or by_.shape == grp_shape: + array = array.astype(agg.dtype) + if cast_to is not None: + array = array.astype(cast_to) + return array + + # Made a design choice here to have `preprocess` handle both array and group_idx + # Example: for reversing, we need to reverse the whole array, not just reverse + # each block independently + inp = AlignedArrays(array=array, group_idx=by_) + if agg.preprocess: + inp = agg.preprocess(inp) + + if not has_dask: + final_state = chunk_scan(inp, axis=single_axis, agg=agg, dtype=agg.dtype) + result = _finalize_scan(final_state, dtype=agg.dtype) + else: + result = dask_groupby_scan(inp.array, inp.group_idx, axes=axis_, agg=agg) + + # Made a design choice here to have `postprocess` handle both array and group_idx + out = AlignedArrays(array=result, group_idx=by_) + if agg.finalize: + out = agg.finalize(out) + + if cast_to is not None: + return out.array.astype(cast_to) + return out.array + + +def chunk_scan(inp: AlignedArrays, *, axis: int, agg: Scan, dtype=None, keepdims=None) -> ScanState: + from .aggregations import AlignedArrays, ScanState, generic_aggregate + + assert axis == inp.array.ndim - 1 + + # I don't think we need to re-factorize here unless we are grouping by a dask array + accumulated = generic_aggregate( + inp.group_idx, + inp.array, + axis=axis, + engine="flox", + func=agg.scan, + dtype=dtype, + fill_value=agg.identity, + ) + result = AlignedArrays(array=accumulated, group_idx=inp.group_idx) + return ScanState(result=result, state=None) + + +def grouped_reduce(inp: AlignedArrays, *, agg: Scan, axis: int, keepdims=None) -> ScanState: + from .aggregations import AlignedArrays, ScanState + from .core import chunk_reduce + + assert axis == inp.array.ndim - 1 + reduced = chunk_reduce( + inp.array, + inp.group_idx, + func=(agg.reduction,), + axis=axis, + engine="flox", + dtype=inp.array.dtype, + fill_value=agg.identity, + expected_groups=None, + ) + return ScanState( + state=AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"]), + result=None, + ) + + +def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays: + from .aggregations import AlignedArrays + + return AlignedArrays(group_idx=group_idx, array=array) + + +def _finalize_scan(block: ScanState, dtype) -> np.ndarray: + assert block.result is not None + return block.result.array.astype(dtype, copy=False) + + +def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray: + from dask.array import map_blocks + from dask.array.reductions import cumreduction as scan + + from flox.aggregations import scan_binary_op + + from .core import _unify_chunks + + if len(axes) > 1: + raise NotImplementedError("Scans are only supported along a single axis.") + (axis,) = axes + + array, by = _unify_chunks(array, by) + + # 1. zip together group indices & array + zipped = map_blocks( + _zip, + by, + array, + dtype=array.dtype, + meta=array._meta, + name="groupby-scan-preprocess", + ) + + scan_ = partial(chunk_scan, agg=agg) + # dask tokenizing error workaround + scan_.__name__ = scan_.func.__name__ # type: ignore[attr-defined] + + # 2. Run the scan + accumulated = scan( + func=scan_, + binop=partial(scan_binary_op, agg=agg), + ident=agg.identity, + x=zipped, + axis=axis, + # TODO: support method="sequential" here. + method="blelloch", + preop=partial(grouped_reduce, agg=agg), + dtype=agg.dtype, + ) + + # 3. 
Unzip and extract the final result array, discard groups + result = map_blocks(partial(_finalize_scan, dtype=agg.dtype), accumulated, dtype=agg.dtype) + + assert result.chunks == array.chunks + + return result + + +__all__ = ["groupby_scan"] diff --git a/flox/types.py b/flox/types.py index c3bb32c22..763e7106a 100644 --- a/flox/types.py +++ b/flox/types.py @@ -1,4 +1,7 @@ -from typing import Any, TypeAlias +from typing import TYPE_CHECKING, Any, Literal, TypeAlias + +if TYPE_CHECKING: + import numpy as np try: import cubed.Array as CubedArray @@ -11,3 +14,49 @@ except ImportError: DaskArray = Any Graph: TypeAlias = Any # type: ignore[no-redef,misc] + +# Only define these types when type checking to avoid import issues +if TYPE_CHECKING: + # Core array types + T_DuckArray: TypeAlias = "np.ndarray | DaskArray | CubedArray" + T_By: TypeAlias = T_DuckArray + T_Bys = "tuple[T_By, ...]" + + # Expected groups types + T_ExpectIndex = "pd.Index" + T_ExpectIndexTuple = "tuple[T_ExpectIndex, ...]" + T_ExpectIndexOpt = "T_ExpectIndex | None" + T_ExpectIndexOptTuple = "tuple[T_ExpectIndexOpt, ...]" + T_Expect = "Sequence | np.ndarray | T_ExpectIndex" + T_ExpectTuple = "tuple[T_Expect, ...]" + T_ExpectOpt = "Sequence | np.ndarray | T_ExpectIndexOpt" + T_ExpectOptTuple = "tuple[T_ExpectOpt, ...]" + T_ExpectedGroups = "T_Expect | T_ExpectOptTuple" + T_ExpectedGroupsOpt = "T_ExpectedGroups | None" + + # Function and aggregation types + T_Func = "str | Callable" + T_Funcs = "T_Func | Sequence[T_Func]" + T_Agg = "str" # Will be "str | Aggregation" but avoiding circular import + T_Scan = "str" # Will be "str | Scan" but avoiding circular import + + # Axis types + T_Axis = int + T_Axes = "tuple[T_Axis, ...]" + T_AxesOpt = "T_Axis | T_Axes | None" + + # Data types + T_Dtypes = "np.typing.DTypeLike | Sequence[np.typing.DTypeLike] | None" + T_FillValues = "np.typing.ArrayLike | Sequence[np.typing.ArrayLike] | None" + + # Engine and method types + T_Engine = Literal["flox", "numpy", "numba", "numbagg"] + T_EngineOpt = "None | T_Engine" + T_Method = Literal["map-reduce", "blockwise", "cohorts"] + T_MethodOpt = "None | Literal['map-reduce', 'blockwise', 'cohorts']" + + # Binning types + T_IsBins = "bool | Sequence[bool]" + + # Factorize types + FactorProps = "namedtuple('FactorProps', 'offset_group nan_sentinel nanmask')" diff --git a/flox/utils.py b/flox/utils.py new file mode 100644 index 000000000..4fac4550f --- /dev/null +++ b/flox/utils.py @@ -0,0 +1,40 @@ +"""Utility classes and functions for flox.""" + +from __future__ import annotations + +from enum import Enum, auto + +import numpy as np + + +class ReindexArrayType(Enum): + """ + Enum describing which array type to reindex to. + + These are enumerated, rather than accepting a constructor, + because we might want to optimize for specific array types, + and because they don't necessarily have the same signature. + + For example, scipy.sparse.COO only supports a fill_value of 0. 
+ """ + + AUTO = auto() + NUMPY = auto() + SPARSE_COO = auto() + # Sadly, scipy.sparse.coo_array only supports fill_value = 0 + # SCIPY_SPARSE_COO = auto() + # SPARSE_GCXS = auto() + + def is_same_type(self, other) -> bool: + match self: + case ReindexArrayType.AUTO: + return True + case ReindexArrayType.NUMPY: + return isinstance(other, np.ndarray) + case ReindexArrayType.SPARSE_COO: + import sparse + + return isinstance(other, sparse.COO) + + +__all__ = ["ReindexArrayType"] diff --git a/tests/test_core.py b/tests/test_core.py index 7499bf996..505f0cd28 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -27,14 +27,14 @@ _is_sparse_supported_reduction, _normalize_indexes, _validate_reindex, - factorize_, find_group_cohorts, groupby_reduce, - groupby_scan, rechunk_for_cohorts, - reindex_, subset_to_blocks, ) +from flox.factorize import factorize_ +from flox.reindex import reindex_ +from flox.scan import groupby_scan from . import ( ALL_FUNCS, @@ -2148,7 +2148,7 @@ def test_reindex_sparse(size): ReindexStrategy(blockwise=True, array_type=ReindexArrayType.SPARSE_COO) reindex = ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO) - original_reindex = flox.core.reindex_ + original_reindex = flox.reindex.reindex_ def mocked_reindex(*args, **kwargs): res = original_reindex(*args, **kwargs) @@ -2162,7 +2162,7 @@ def mocked_reindex(*args, **kwargs): def raise_error(self): raise AttributeError("Access to '_data' is not allowed.") - with patch("flox.core.reindex_") as mocked_reindex_func: + with patch("flox.reindex.reindex_") as mocked_reindex_func: with patch.object(pd.RangeIndex, "_data", property(raise_error)): mocked_reindex_func.side_effect = mocked_reindex actual, *_ = groupby_reduce( diff --git a/tests/test_properties.py b/tests/test_properties.py index a1b105116..fa0b7204f 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -17,8 +17,9 @@ from hypothesis import assume, given, note, settings import flox -from flox.core import _is_sparse_supported_reduction, groupby_reduce, groupby_scan +from flox.core import _is_sparse_supported_reduction, groupby_reduce from flox.lib import sparse_array_type +from flox.scan import groupby_scan from flox.xrutils import ( _contains_cftime_datetimes, _to_pytimedelta, From 13b50f8a8ddff6145c807e7be2b6e6d77a93b4b8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 16 Jul 2025 15:52:15 -0600 Subject: [PATCH 2/3] refactor: extract dask-specific functions to dask.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move dask_groupby_agg, dask_groupby_scan, _grouped_combine, and _unify_chunks to flox/dask.py - Update imports in core.py and scan.py to use new dask module locations - Keep reindex_intermediates in core.py as requested - Maintain backward compatibility and library functionality 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- flox/core.py | 665 +-------------------------------------------- flox/cubed.py | 165 +++++++++++ flox/dask.py | 633 ++++++++++++++++++++++++++++++++++++++++++ flox/scan.py | 53 +--- tests/test_core.py | 3 +- 5 files changed, 807 insertions(+), 712 deletions(-) create mode 100644 flox/cubed.py create mode 100644 flox/dask.py diff --git a/flox/core.py b/flox/core.py index c69f275a8..ae7d3cdda 100644 --- a/flox/core.py +++ b/flox/core.py @@ -11,7 +11,6 @@ from dataclasses import dataclass from functools import partial from itertools import product -from numbers import Integral from typing import ( 
TYPE_CHECKING, Any, @@ -39,14 +38,13 @@ ) from .cache import memoize from .factorize import _factorize_multiple, factorize_ -from .lib import ArrayLayer, dask_array_type, sparse_array_type +from .lib import dask_array_type, sparse_array_type from .reindex import reindex_ from .utils import ReindexArrayType from .xrutils import ( _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric, - is_chunked_array, is_duck_array, is_duck_cubed_array, is_duck_dask_array, @@ -64,7 +62,7 @@ HAS_SPARSE = module_available("sparse") if TYPE_CHECKING: - from .types import CubedArray, DaskArray, Graph + from .types import CubedArray, DaskArray T_DuckArray: TypeAlias = np.ndarray | DaskArray | CubedArray # Any ? T_By: TypeAlias = T_DuckArray @@ -1116,126 +1114,6 @@ def listify_groups(x: IntermediateDict): return list(np.atleast_1d(x["groups"].squeeze())) -def _grouped_combine( - x_chunk, - agg: Aggregation, - axis: T_Axes, - keepdims: bool, - engine: T_Engine, - is_aggregate: bool = False, - sort: bool = True, -) -> IntermediateDict: - """Combine intermediates step of tree reduction.""" - from dask.utils import deepmap - - combine = agg.combine - - if isinstance(x_chunk, dict): - # Only one block at final step; skip one extra groupby - return x_chunk - - if len(axis) != 1: - # when there's only a single axis of reduction, we can just concatenate later, - # reindexing is unnecessary - # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated - unique_groups = _find_unique_groups(x_chunk) - x_chunk = deepmap( - partial( - reindex_intermediates, agg=agg, unique_groups=unique_groups, array_type=ReindexArrayType.AUTO - ), - x_chunk, - ) - - # these are negative axis indices useful for concatenating the intermediates - neg_axis = tuple(range(-len(axis), 0)) - - groups = _conc2(x_chunk, "groups", axis=neg_axis) - - if agg.reduction_type == "argreduce": - # If "nanlen" was added for masking later, we need to account for that - if agg.chunk[-1] == "nanlen": - slicer = slice(None, -1) - else: - slicer = slice(None, None) - - # We need to send the intermediate array values & indexes at the same time - # intermediates are (value e.g. max, index e.g. argmax, counts) - array_idx = tuple(_conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)) - - # for a single element along axis, we don't want to run the argreduction twice - # This happens when we are reducing along an axis with a single chunk. 
- avoid_reduction = array_idx[0].shape[axis[0]] == 1 - if avoid_reduction: - results: IntermediateDict = { - "groups": groups, - "intermediates": list(array_idx), - } - else: - results = chunk_argreduce( - array_idx, - groups, - # count gets treated specially next - func=combine[slicer], # type: ignore[arg-type] - axis=axis, - expected_groups=None, - fill_value=agg.fill_value["intermediate"][slicer], - dtype=agg.dtype["intermediate"][slicer], - engine=engine, - sort=sort, - ) - - if agg.chunk[-1] == "nanlen": - counts = _conc2(x_chunk, key1="intermediates", key2=2, axis=axis) - - if avoid_reduction: - results["intermediates"].append(counts) - else: - # sum the counts - results["intermediates"].append( - chunk_reduce( - counts, - groups, - func="sum", - axis=axis, - expected_groups=None, - fill_value=(0,), - dtype=(np.intp,), - engine=engine, - sort=sort, - user_dtype=agg.dtype["user"], - )["intermediates"][0] - ) - - elif agg.reduction_type == "reduce": - # Here we reduce the intermediates individually - results = {"groups": None, "intermediates": []} - for idx, (combine_, fv, dtype) in enumerate( - zip(combine, agg.fill_value["intermediate"], agg.dtype["intermediate"]) - ): - assert combine_ is not None - array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) - if array.shape[-1] == 0: - # all empty when combined - results["intermediates"].append(np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype)) - results["groups"] = np.empty(shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype) - else: - _results = chunk_reduce( - array, - groups, - func=combine_, - axis=axis, - expected_groups=None, - fill_value=(fv,), - dtype=(dtype,), - engine=engine, - sort=sort, - user_dtype=agg.dtype["user"], - ) - results["intermediates"].append(*_results["intermediates"]) - results["groups"] = _results["groups"] - return results - - def _reduce_blockwise( array, by, @@ -1285,541 +1163,6 @@ def _reduce_blockwise( return result -def _normalize_indexes(ndim: int, flatblocks: Sequence[int], blkshape: tuple[int, ...]) -> tuple: - """ - .blocks accessor can only accept one iterable at a time, - but can handle multiple slices. - To minimize tasks and layers, we normalize to produce slices - along as many axes as possible, and then repeatedly apply - any remaining iterables in a loop. 
- - TODO: move this upstream - """ - unraveled = np.unravel_index(flatblocks, blkshape) - - normalized: list[int | slice | list[int]] = [] - for ax, idx in enumerate(unraveled): - i = _unique(idx).squeeze() - if i.ndim == 0: - normalized.append(i.item()) - else: - if len(i) == blkshape[ax] and np.array_equal(i, np.arange(blkshape[ax])): - normalized.append(slice(None)) - elif _issorted(i) and np.array_equal(i, np.arange(i[0], i[-1] + 1)): - start = None if i[0] == 0 else i[0] - stop = i[-1] + 1 - stop = None if stop == blkshape[ax] else stop - normalized.append(slice(start, stop)) - else: - normalized.append(list(i)) - full_normalized = (slice(None),) * (ndim - len(normalized)) + tuple(normalized) - - # has no iterables - noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized) - # has all iterables - alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")} - - mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values()))) # type: ignore[arg-type, var-annotated] - - full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter)) - - return full_tuple - - -def subset_to_blocks( - array: DaskArray, - flatblocks: Sequence[int], - blkshape: tuple[int, ...] | None = None, - reindexer=identity, - chunks_as_array: tuple[np.ndarray, ...] | None = None, -) -> ArrayLayer: - """ - Advanced indexing of .blocks such that we always get a regular array back. - - Parameters - ---------- - array : dask.array - flatblocks : flat indices of blocks to extract - blkshape : shape of blocks with which to unravel flatblocks - - Returns - ------- - dask.array - """ - from dask.base import tokenize - - if blkshape is None: - blkshape = array.blocks.shape - - if chunks_as_array is None: - chunks_as_array = tuple(np.array(c) for c in array.chunks) - - index = _normalize_indexes(array.ndim, flatblocks, blkshape) - - # These rest is copied from dask.array.core.py with slight modifications - index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index) - - name = "groupby-cohort-" + tokenize(array, index) - new_keys = array._key_array[index] - - squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index) - chunks = tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed)) - - keys = itertools.product(*(range(len(c)) for c in chunks)) - layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys} - return ArrayLayer(layer=layer, chunks=chunks, name=name) - - -def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]: - import dask.array - from dask.highlevelgraph import HighLevelGraph - - groups_token = f"group-{reduced.name}" - first_block = reduced.ndim * (0,) - layer: Graph = {(groups_token, 0): (operator.getitem, (reduced.name, *first_block), "groups")} - groups: tuple[DaskArray] = ( - dask.array.Array( - HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]), - groups_token, - chunks=((np.nan,),), - meta=np.array([], dtype=dtype), - ), - ) - - return groups - - -def _unify_chunks(array, by): - from dask.array import from_array, unify_chunks - - inds = tuple(range(array.ndim)) - - # Unifying chunks is necessary for argreductions. 
- # We need to rechunk before zipping up with the index - # let's always do it anyway - if not is_duck_dask_array(by): - # chunk numpy arrays like the input array - # This removes an extra rechunk-merge layer that would be - # added otherwise - chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)) - - by = from_array(by, chunks=chunks) - _, (array, by) = unify_chunks(array, inds, by, inds[-by.ndim :]) - - return array, by - - -def dask_groupby_agg( - array: DaskArray, - by: T_By, - *, - agg: Aggregation, - expected_groups: pd.RangeIndex | None, - reindex: ReindexStrategy, - axis: T_Axes = (), - fill_value: Any = None, - method: T_Method = "map-reduce", - engine: T_Engine = "numpy", - sort: bool = True, - chunks_cohorts=None, -) -> tuple[DaskArray, tuple[pd.Index | np.ndarray | DaskArray]]: - import dask.array - from dask.array.core import slices_from_chunks - from dask.highlevelgraph import HighLevelGraph - - from .dask_array_ops import _tree_reduce - - # I think _tree_reduce expects this - assert isinstance(axis, Sequence) - assert all(ax >= 0 for ax in axis) - - inds = tuple(range(array.ndim)) - name = f"groupby_{agg.name}" - - if expected_groups is None and reindex.blockwise: - raise ValueError("reindex.blockwise must be False-y if expected_groups is not provided.") - if method == "cohorts" and reindex.blockwise: - raise ValueError("reindex.blockwise must be False-y if method is 'cohorts'.") - - by_input = by - - array, by = _unify_chunks(array, by) - - # tokenize here since by has already been hashed if its numpy - token = dask.base.tokenize(array, by, agg, expected_groups, axis, method) - - # preprocess the array: - # - for argreductions, this zips the index together with the array block - # - not necessary for blockwise with argreductions - # - if this is needed later, we can fix this then - if agg.preprocess and method != "blockwise": - array = agg.preprocess(array, axis=axis) - - # 1. We first apply the groupby-reduction blockwise to generate "intermediates" - # 2. These intermediate results are combined to generate the final result using a - # "map-reduce" or "tree reduction" approach. - # There are two ways: - # a. "_simple_combine": Where it makes sense, we tree-reduce the reduction, - # NOT the groupby-reduction for a speed boost. This is what xhistogram does (effectively), - # It requires that all blocks contain all groups after the initial blockwise step (1) i.e. - # reindex.blockwise=True, and we must know expected_groups - # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction. 
- # This allows us to discover groups at compute time, support argreductions, lower intermediate - # memory usage (but method="cohorts" would also work to reduce memory in some cases) - labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None - do_grouped_combine = ( - _is_arg_reduction(agg) - or labels_are_unknown - or (_is_first_last_reduction(agg) and array.dtype.kind != "f") - ) - do_simple_combine = not do_grouped_combine - - if method == "blockwise": - # use the "non dask" code path, but applied blockwise - blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex) - else: - # choose `chunk_reduce` or `chunk_argreduce` - blockwise_method = partial( - _get_chunk_reduction(agg.reduction_type), - func=agg.chunk, - reindex=reindex.blockwise, - fill_value=agg.fill_value["intermediate"], - dtype=agg.dtype["intermediate"], - user_dtype=agg.dtype["user"], - ) - if do_simple_combine: - # Add a dummy dimension that then gets reduced over - blockwise_method = tlz.compose(_expand_dims, blockwise_method) - - # apply reduction on chunk - intermediate = dask.array.blockwise( - partial( - blockwise_method, - axis=axis, - expected_groups=expected_groups if reindex.blockwise else None, - engine=engine, - sort=sort, - ), - # output indices are the same as input indices - # Unlike xhistogram, we don't always know what the size of the group - # dimension will be unless reindex=True - inds, - array, - inds, - by, - inds[-by.ndim :], - concatenate=False, - dtype=array.dtype, # this is purely for show - meta=array._meta, - align_arrays=False, - name=f"{name}-chunk-{token}", - ) - - group_chunks: tuple[tuple[int | float, ...]] - - if method in ["map-reduce", "cohorts"]: - combine: Callable[..., IntermediateDict] = ( - partial(_simple_combine, reindex=reindex) - if do_simple_combine - else partial(_grouped_combine, engine=engine, sort=sort) - ) - - tree_reduce = partial( - dask.array.reductions._tree_reduce, - name=f"{name}-simple-reduce", - dtype=array.dtype, - axis=axis, - keepdims=True, - concatenate=False, - ) - aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex) - - # Each chunk of `reduced`` is really a dict mapping - # 1. reduction name to array - # 2. 
"groups" to an array of group labels - # Note: it does not make sense to interpret axis relative to - # shape of intermediate results after the blockwise call - if method == "map-reduce": - reduced = tree_reduce( - intermediate, - combine=partial(combine, agg=agg), - aggregate=partial(aggregate, expected_groups=expected_groups), - ) - if labels_are_unknown: - groups = _extract_unknown_groups(reduced, dtype=by.dtype) - group_chunks = ((np.nan,),) - else: - assert expected_groups is not None - groups = (expected_groups,) - group_chunks = ((len(expected_groups),),) - - elif method == "cohorts": - assert chunks_cohorts - block_shape = array.blocks.shape[-len(axis) :] - - out_name = f"{name}-reduce-{method}-{token}" - groups_ = [] - chunks_as_array = tuple(np.array(c) for c in array.chunks) - dsk: Graph = {} - for icohort, (blks, cohort) in enumerate(chunks_cohorts.items()): - cohort_index = pd.Index(cohort) - reindexer = ( - partial( - reindex_intermediates, - agg=agg, - unique_groups=cohort_index, - array_type=reindex.array_type, - ) - if do_simple_combine - else identity - ) - subset = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array) - dsk |= subset.layer # type: ignore[operator] - # now that we have reindexed, we can set reindex=True explicitlly - new_reindex = ReindexStrategy(blockwise=do_simple_combine, array_type=reindex.array_type) - _tree_reduce( - subset, - out_dsk=dsk, - name=out_name, - block_index=icohort, - axis=axis, - combine=partial(combine, agg=agg, reindex=new_reindex, keepdims=True), - aggregate=partial( - aggregate, expected_groups=cohort_index, reindex=new_reindex, keepdims=True - ), - ) - # This is done because pandas promotes to 64-bit types when an Index is created - # So we use the index to generate the return value for consistency with "map-reduce" - # This is important on windows - groups_.append(cohort_index.values) - - graph = HighLevelGraph.from_collections(out_name, dsk, dependencies=[intermediate]) - - out_chunks = list(array.chunks) - out_chunks[axis[-1]] = tuple(len(c) for c in chunks_cohorts.values()) - for ax in axis[:-1]: - out_chunks[ax] = (1,) - reduced = dask.array.Array(graph, out_name, out_chunks, meta=array._meta) - - groups = (np.concatenate(groups_),) - group_chunks = (tuple(len(cohort) for cohort in groups_),) - - elif method == "blockwise": - reduced = intermediate - if reindex.blockwise: - if TYPE_CHECKING: - assert expected_groups is not None - # TODO: we could have `expected_groups` be a dask array with appropriate chunks - # for now, we have a numpy array that is interpreted as listing all group labels - # that are present in every chunk - groups = (expected_groups,) - group_chunks = ((len(expected_groups),),) - else: - # TODO: use chunks_cohorts here; hard because chunks_cohorts does not include all-NaN blocks - # but the array after applying the blockwise op; does. We'd have to insert a subsetting op. 
- # Here one input chunk → one output chunks - # find number of groups in each chunk, this is needed for output chunks - # along the reduced axis - # TODO: this logic is very specialized for the resampling case - slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) - groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) - groups = (np.concatenate(groups_in_block),) - ngroups_per_block = tuple(len(grp) for grp in groups_in_block) - group_chunks = (ngroups_per_block,) - else: - raise ValueError(f"Unknown method={method}.") - - # Adjust output for any new dimensions added, example for multiple quantiles - new_dims_shape = tuple(dim.size for dim in agg.new_dims if not dim.is_scalar) - new_inds = tuple(range(-len(new_dims_shape), 0)) - out_inds = new_inds + inds[: -len(axis)] + (inds[-1],) - output_chunks = new_dims_shape + reduced.chunks[: -len(axis)] + group_chunks - new_axes = dict(zip(new_inds, new_dims_shape)) - - if method == "blockwise" and len(axis) > 1: - # The final results are available but the blocks along axes - # need to be reshaped to axis=-1 - # I don't know that this is possible with blockwise - # All other code paths benefit from an unmaterialized Blockwise layer - reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks) - - # Can't use map_blocks because it forces concatenate=True along drop_axes, - result = dask.array.blockwise( - _extract_result, - out_inds, - reduced, - inds, - adjust_chunks=dict(zip(out_inds, output_chunks)), - key=agg.name, - name=f"{name}-{token}", - concatenate=False, - new_axes=new_axes, - meta=reindex.get_dask_meta(array, dtype=agg.dtype["final"], fill_value=agg.fill_value[agg.name]), - ) - - return (result, groups) - - -def cubed_groupby_agg( - array: CubedArray, - by: T_By, - agg: Aggregation, - expected_groups: pd.Index | None, - reindex: ReindexStrategy, - axis: T_Axes = (), - fill_value: Any = None, - method: T_Method = "map-reduce", - engine: T_Engine = "numpy", - sort: bool = True, - chunks_cohorts=None, -) -> tuple[CubedArray, tuple[pd.Index | np.ndarray | CubedArray]]: - import cubed - import cubed.core.groupby - - # I think _tree_reduce expects this - assert isinstance(axis, Sequence) - assert all(ax >= 0 for ax in axis) - - if method == "blockwise": - assert by.ndim == 1 - assert expected_groups is not None - - def _reduction_func(a, by, axis, start_group, num_groups): - # adjust group labels to start from 0 for each chunk - by_for_chunk = by - start_group - expected_groups_for_chunk = pd.RangeIndex(num_groups) - - axis = (axis,) # convert integral axis to tuple - - blockwise_method = partial( - _reduce_blockwise, - agg=agg, - axis=axis, - expected_groups=expected_groups_for_chunk, - fill_value=fill_value, - engine=engine, - sort=sort, - reindex=reindex, - ) - out = blockwise_method(a, by_for_chunk) - return out[agg.name] - - num_groups = len(expected_groups) - result = cubed.core.groupby.groupby_blockwise( - array, by, axis=axis, func=_reduction_func, num_groups=num_groups - ) - groups = (expected_groups,) - return (result, groups) - - else: - inds = tuple(range(array.ndim)) - - by_input = by - - # Unifying chunks is necessary for argreductions. 
- # We need to rechunk before zipping up with the index - # let's always do it anyway - if not is_chunked_array(by): - # chunk numpy arrays like the input array - chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)) - - by = cubed.from_array(by, chunks=chunks, spec=array.spec) - _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :]) - - # Cubed's groupby_reduction handles the generation of "intermediates", and the - # "map-reduce" combination step, so we don't have to do that here. - # Only the equivalent of "_simple_combine" is supported, there is no - # support for "_grouped_combine". - labels_are_unknown = is_chunked_array(by_input) and expected_groups is None - do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown - - assert do_simple_combine - assert method == "map-reduce" - assert expected_groups is not None - assert reindex.blockwise is True - assert len(axis) == 1 # one axis/grouping - - def _groupby_func(a, by, axis, intermediate_dtype, num_groups): - blockwise_method = partial( - _get_chunk_reduction(agg.reduction_type), - func=agg.chunk, - fill_value=agg.fill_value["intermediate"], - dtype=agg.dtype["intermediate"], - reindex=reindex, - user_dtype=agg.dtype["user"], - axis=axis, - expected_groups=expected_groups, - engine=engine, - sort=sort, - ) - out = blockwise_method(a, by) - # Convert dict to one that cubed understands, dropping groups since they are - # known, and the same for every block. - return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])} - - def _groupby_combine(a, axis, dummy_axis, dtype, keepdims): - # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed - # only combine over the dummy axis, to preserve grouping along 'axis' - dtype = dict(dtype) - out = {} - for idx, combine in enumerate(agg.simple_combine): - field = f"f{idx}" - out[field] = combine(a[field], axis=dummy_axis, keepdims=keepdims) - return out - - def _groupby_aggregate(a, **kwargs): - # Convert cubed dict to one that _finalize_results works with - results = {"groups": expected_groups, "intermediates": a.values()} - out = _finalize_results(results, agg, axis, expected_groups, reindex) - return out[agg.name] - - # convert list of dtypes to a structured dtype for cubed - intermediate_dtype = [(f"f{i}", dtype) for i, dtype in enumerate(agg.dtype["intermediate"])] - dtype = agg.dtype["final"] - num_groups = len(expected_groups) - - result = cubed.core.groupby.groupby_reduction( - array, - by, - func=_groupby_func, - combine_func=_groupby_combine, - aggregate_func=_groupby_aggregate, - axis=axis, - intermediate_dtype=intermediate_dtype, - dtype=dtype, - num_groups=num_groups, - ) - - groups = (expected_groups,) - - return (result, groups) - - -def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axes, group_chunks) -> DaskArray: - import dask.array - from dask.highlevelgraph import HighLevelGraph - - nblocks = tuple(reduced.numblocks[ax] for ax in axis) - output_chunks = reduced.chunks[: -len(axis)] + ((1,) * (len(axis) - 1),) + group_chunks - - # extract results from the dict - ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks) - layer2: dict[tuple, tuple] = {} - name = f"reshape-{reduced.name}" - - for ochunk in itertools.product(*ochunks): - inchunk = ochunk[: -len(axis)] + np.unravel_index(ochunk[-1], nblocks) - layer2[(name, *ochunk)] = (reduced.name, *inchunk) - - layer2: Graph - return dask.array.Array( - 
HighLevelGraph.from_collections(name, layer2, dependencies=[reduced]), - name, - chunks=output_chunks, - dtype=reduced.dtype, - ) - - def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray: from dask.array.core import deepfirst @@ -2352,6 +1695,8 @@ def groupby_reduce( "Reduction for Cubed arrays is only implemented for methods 'map-reduce' and 'blockwise'." ) + from .cubed import cubed_groupby_agg + partial_agg = partial(cubed_groupby_agg, **kwargs) result, groups = partial_agg( @@ -2449,6 +1794,8 @@ def groupby_reduce( if kwargs["fill_value"] is None: kwargs["fill_value"] = agg.fill_value[agg.name] + from .dask import dask_groupby_agg + partial_agg = partial(dask_groupby_agg, **kwargs) # if preferred method is already blockwise, no need to rechunk diff --git a/flox/cubed.py b/flox/cubed.py new file mode 100644 index 000000000..3f10baa0a --- /dev/null +++ b/flox/cubed.py @@ -0,0 +1,165 @@ +"""Cubed-specific functions for groupby operations.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import partial +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd + +if TYPE_CHECKING: + from .aggregations import Aggregation + from .core import T_Axes, T_Engine, T_Method + from .types import CubedArray, T_By + from .utils import ReindexStrategy + +from .core import ( + _finalize_results, + _get_chunk_reduction, + _is_arg_reduction, + _reduce_blockwise, + is_chunked_array, +) +from .utils import ReindexStrategy + + +def cubed_groupby_agg( + array: CubedArray, + by: T_By, + agg: Aggregation, + expected_groups: pd.Index | None, + reindex: ReindexStrategy, + axis: T_Axes = (), + fill_value: Any = None, + method: T_Method = "map-reduce", + engine: T_Engine = "numpy", + sort: bool = True, + chunks_cohorts=None, +) -> tuple[CubedArray, tuple[pd.Index | np.ndarray | CubedArray]]: + import cubed + import cubed.core.groupby + + # I think _tree_reduce expects this + assert isinstance(axis, Sequence) + assert all(ax >= 0 for ax in axis) + + if method == "blockwise": + assert by.ndim == 1 + assert expected_groups is not None + + def _reduction_func(a, by, axis, start_group, num_groups): + # adjust group labels to start from 0 for each chunk + by_for_chunk = by - start_group + expected_groups_for_chunk = pd.RangeIndex(num_groups) + + axis = (axis,) # convert integral axis to tuple + + blockwise_method = partial( + _reduce_blockwise, + agg=agg, + axis=axis, + expected_groups=expected_groups_for_chunk, + fill_value=fill_value, + engine=engine, + sort=sort, + reindex=reindex, + ) + out = blockwise_method(a, by_for_chunk) + return out[agg.name] + + num_groups = len(expected_groups) + result = cubed.core.groupby.groupby_blockwise( + array, by, axis=axis, func=_reduction_func, num_groups=num_groups + ) + groups = (expected_groups,) + return (result, groups) + + else: + inds = tuple(range(array.ndim)) + + by_input = by + + # Unifying chunks is necessary for argreductions. + # We need to rechunk before zipping up with the index + # let's always do it anyway + if not is_chunked_array(by): + # chunk numpy arrays like the input array + chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)) + + by = cubed.from_array(by, chunks=chunks, spec=array.spec) + _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :]) + + # Cubed's groupby_reduction handles the generation of "intermediates", and the + # "map-reduce" combination step, so we don't have to do that here. 
+ # Only the equivalent of "_simple_combine" is supported, there is no + # support for "_grouped_combine". + labels_are_unknown = is_chunked_array(by_input) and expected_groups is None + do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown + + assert do_simple_combine + assert method == "map-reduce" + assert expected_groups is not None + assert reindex.blockwise is True + assert len(axis) == 1 # one axis/grouping + + def _groupby_func(a, by, axis, intermediate_dtype, num_groups): + blockwise_method = partial( + _get_chunk_reduction(agg.reduction_type), + func=agg.chunk, + fill_value=agg.fill_value["intermediate"], + dtype=agg.dtype["intermediate"], + reindex=reindex, + user_dtype=agg.dtype["user"], + axis=axis, + expected_groups=expected_groups, + engine=engine, + sort=sort, + ) + out = blockwise_method(a, by) + # Convert dict to one that cubed understands, dropping groups since they are + # known, and the same for every block. + return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])} + + def _groupby_combine(a, axis, dummy_axis, dtype, keepdims): + # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed + # only combine over the dummy axis, to preserve grouping along 'axis' + dtype = dict(dtype) + out = {} + for idx, combine in enumerate(agg.simple_combine): + field = f"f{idx}" + out[field] = combine(a[field], axis=dummy_axis, keepdims=keepdims) + return out + + def _groupby_aggregate(a, **kwargs): + # Convert cubed dict to one that _finalize_results works with + results = {"groups": expected_groups, "intermediates": a.values()} + out = _finalize_results(results, agg, axis, expected_groups, reindex) + return out[agg.name] + + # convert list of dtypes to a structured dtype for cubed + intermediate_dtype = [(f"f{i}", dtype) for i, dtype in enumerate(agg.dtype["intermediate"])] + dtype = agg.dtype["final"] + num_groups = len(expected_groups) + + result = cubed.core.groupby.groupby_reduction( + array, + by, + func=_groupby_func, + combine_func=_groupby_combine, + aggregate_func=_groupby_aggregate, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + num_groups=num_groups, + ) + + groups = (expected_groups,) + + return (result, groups) + + +__all__ = [ + "cubed_groupby_agg", +] diff --git a/flox/dask.py b/flox/dask.py new file mode 100644 index 000000000..f8bf00555 --- /dev/null +++ b/flox/dask.py @@ -0,0 +1,633 @@ +"""Dask-specific functions for groupby operations.""" + +from __future__ import annotations + +import itertools +import operator +from collections.abc import Callable, Sequence +from functools import partial +from numbers import Integral +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd +import toolz as tlz + +if TYPE_CHECKING: + from .aggregations import Aggregation, Scan + from .core import IntermediateDict, T_Axes, T_Engine, T_Method + from .lib import ArrayLayer + from .types import DaskArray, Graph, T_By + from .utils import ReindexArrayType, ReindexStrategy + +from .core import ( + _aggregate, + _conc2, + _expand_dims, + _extract_result, + _find_unique_groups, + _get_chunk_reduction, + _is_arg_reduction, + _is_first_last_reduction, + _issorted, + _reduce_blockwise, + _simple_combine, + _unique, + chunk_argreduce, + chunk_reduce, + identity, + reindex_intermediates, +) +from .utils import ReindexArrayType, ReindexStrategy +from .xrutils import is_duck_dask_array + + +def _unify_chunks(array, by): + from dask.array import from_array, 
unify_chunks + + inds = tuple(range(array.ndim)) + + # Unifying chunks is necessary for argreductions. + # We need to rechunk before zipping up with the index + # let's always do it anyway + if not is_duck_dask_array(by): + # chunk numpy arrays like the input array + # This removes an extra rechunk-merge layer that would be + # added otherwise + chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)) + + by = from_array(by, chunks=chunks) + _, (array, by) = unify_chunks(array, inds, by, inds[-by.ndim :]) + + return array, by + + +def _grouped_combine( + x_chunk, + agg: Aggregation, + axis: T_Axes, + keepdims: bool, + engine: T_Engine, + is_aggregate: bool = False, + sort: bool = True, +) -> IntermediateDict: + """Combine intermediates step of tree reduction.""" + from dask.utils import deepmap + + combine = agg.combine + + if isinstance(x_chunk, dict): + # Only one block at final step; skip one extra groupby + return x_chunk + + if len(axis) != 1: + # when there's only a single axis of reduction, we can just concatenate later, + # reindexing is unnecessary + # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated + unique_groups = _find_unique_groups(x_chunk) + x_chunk = deepmap( + partial( + reindex_intermediates, agg=agg, unique_groups=unique_groups, array_type=ReindexArrayType.AUTO + ), + x_chunk, + ) + + # these are negative axis indices useful for concatenating the intermediates + neg_axis = tuple(range(-len(axis), 0)) + + groups = _conc2(x_chunk, "groups", axis=neg_axis) + + if agg.reduction_type == "argreduce": + # If "nanlen" was added for masking later, we need to account for that + if agg.chunk[-1] == "nanlen": + slicer = slice(None, -1) + else: + slicer = slice(None, None) + + # We need to send the intermediate array values & indexes at the same time + # intermediates are (value e.g. max, index e.g. argmax, counts) + array_idx = tuple(_conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)) + + # for a single element along axis, we don't want to run the argreduction twice + # This happens when we are reducing along an axis with a single chunk. + avoid_reduction = array_idx[0].shape[axis[0]] == 1 + if avoid_reduction: + results: IntermediateDict = { + "groups": groups, + "intermediates": list(array_idx), + } + else: + results = chunk_argreduce( + array_idx, + groups, + # count gets treated specially next + func=combine[slicer], # type: ignore[arg-type] + axis=axis, + expected_groups=None, + fill_value=agg.fill_value["intermediate"][slicer], + dtype=agg.dtype["intermediate"][slicer], + engine=engine, + sort=sort, + ) + + if agg.chunk[-1] == "nanlen": + counts = _conc2(x_chunk, key1="intermediates", key2=2, axis=axis) + + if avoid_reduction: + results["intermediates"].append(counts) + else: + # sum the counts + results["intermediates"].append( + chunk_reduce( + counts, + groups, + func="sum", + axis=axis, + expected_groups=None, + fill_value=(0,), + dtype=(np.intp,), + engine=engine, + sort=sort, + user_dtype=agg.dtype["user"], + )["intermediates"][0] + ) + + elif agg.reduction_type == "reduce": + # Here we reduce the intermediates individually + results = {"groups": None, "intermediates": []} + for idx, (combine_, fv, dtype) in enumerate( + zip(combine, agg.fill_value["intermediate"], agg.dtype["intermediate"]) + ): + assert combine_ is not None + array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) + if array.shape[-1] == 0: + # No groups found in input data. 
Return to avoid a tree-reduce + # step with no data. + results["groups"] = groups + results["intermediates"].append(array) + continue + reduced = chunk_reduce( + array, + groups, + axis=axis, + func=combine_, + expected_groups=None, + fill_value=(fv,), + dtype=(dtype,), + engine=engine, + sort=sort, + user_dtype=agg.dtype["user"], + ) + # we had groups so this should've been set + if results["groups"] is None: + results["groups"] = reduced["groups"] + results["intermediates"].append(reduced["intermediates"][0]) + + # final pass and add keepdims=False. + results["groups"] = results["groups"].squeeze() if not keepdims else results["groups"] + + return results + + +def dask_groupby_agg( + array: DaskArray, + by: T_By, + *, + agg: Aggregation, + expected_groups: pd.RangeIndex | None, + reindex: ReindexStrategy, + axis: T_Axes = (), + fill_value: Any = None, + method: T_Method = "map-reduce", + engine: T_Engine = "numpy", + sort: bool = True, + chunks_cohorts=None, +) -> tuple[DaskArray, tuple[pd.Index | np.ndarray | DaskArray]]: + import dask.array + from dask.array.core import slices_from_chunks + from dask.highlevelgraph import HighLevelGraph + + from .dask_array_ops import _tree_reduce + + # I think _tree_reduce expects this + assert isinstance(axis, Sequence) + assert all(ax >= 0 for ax in axis) + + inds = tuple(range(array.ndim)) + name = f"groupby_{agg.name}" + + if expected_groups is None and reindex.blockwise: + raise ValueError("reindex.blockwise must be False-y if expected_groups is not provided.") + if method == "cohorts" and reindex.blockwise: + raise ValueError("reindex.blockwise must be False-y if method is 'cohorts'.") + + by_input = by + + array, by = _unify_chunks(array, by) + + # tokenize here since by has already been hashed if its numpy + token = dask.base.tokenize(array, by, agg, expected_groups, axis, method) + + # preprocess the array: + # - for argreductions, this zips the index together with the array block + # - not necessary for blockwise with argreductions + # - if this is needed later, we can fix this then + if agg.preprocess and method != "blockwise": + array = agg.preprocess(array, axis=axis) + + # 1. We first apply the groupby-reduction blockwise to generate "intermediates" + # 2. These intermediate results are combined to generate the final result using a + # "map-reduce" or "tree reduction" approach. + # There are two ways: + # a. "_simple_combine": Where it makes sense, we tree-reduce the reduction, + # NOT the groupby-reduction for a speed boost. This is what xhistogram does (effectively), + # It requires that all blocks contain all groups after the initial blockwise step (1) i.e. + # reindex.blockwise=True, and we must know expected_groups + # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction. 
+ # This allows us to discover groups at compute time, support argreductions, lower intermediate + # memory usage (but method="cohorts" would also work to reduce memory in some cases) + labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None + do_grouped_combine = ( + _is_arg_reduction(agg) + or labels_are_unknown + or (_is_first_last_reduction(agg) and array.dtype.kind != "f") + ) + do_simple_combine = not do_grouped_combine + + if method == "blockwise": + # use the "non dask" code path, but applied blockwise + blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex) + else: + # choose `chunk_reduce` or `chunk_argreduce` + blockwise_method = partial( + _get_chunk_reduction(agg.reduction_type), + func=agg.chunk, + reindex=reindex.blockwise, + fill_value=agg.fill_value["intermediate"], + dtype=agg.dtype["intermediate"], + user_dtype=agg.dtype["user"], + ) + if do_simple_combine: + # Add a dummy dimension that then gets reduced over + blockwise_method = tlz.compose(_expand_dims, blockwise_method) + + # apply reduction on chunk + intermediate = dask.array.blockwise( + partial( + blockwise_method, + axis=axis, + expected_groups=expected_groups if reindex.blockwise else None, + engine=engine, + sort=sort, + ), + # output indices are the same as input indices + # Unlike xhistogram, we don't always know what the size of the group + # dimension will be unless reindex=True + inds, + array, + inds, + by, + inds[-by.ndim :], + concatenate=False, + dtype=array.dtype, # this is purely for show + meta=array._meta, + align_arrays=False, + name=f"{name}-chunk-{token}", + ) + + group_chunks: tuple[tuple[int | float, ...]] + + if method in ["map-reduce", "cohorts"]: + combine: Callable[..., IntermediateDict] = ( + partial(_simple_combine, reindex=reindex) + if do_simple_combine + else partial(_grouped_combine, engine=engine, sort=sort) + ) + + tree_reduce = partial( + dask.array.reductions._tree_reduce, + name=f"{name}-simple-reduce", + dtype=array.dtype, + axis=axis, + keepdims=True, + concatenate=False, + ) + aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex) + + # Each chunk of `reduced`` is really a dict mapping + # 1. reduction name to array + # 2. 
"groups" to an array of group labels + # Note: it does not make sense to interpret axis relative to + # shape of intermediate results after the blockwise call + if method == "map-reduce": + reduced = tree_reduce( + intermediate, + combine=partial(combine, agg=agg), + aggregate=partial(aggregate, expected_groups=expected_groups), + ) + if labels_are_unknown: + groups = _extract_unknown_groups(reduced, dtype=by.dtype) + group_chunks = ((np.nan,),) + else: + assert expected_groups is not None + groups = (expected_groups,) + group_chunks = ((len(expected_groups),),) + + elif method == "cohorts": + assert chunks_cohorts + block_shape = array.blocks.shape[-len(axis) :] + + out_name = f"{name}-reduce-{method}-{token}" + groups_ = [] + chunks_as_array = tuple(np.array(c) for c in array.chunks) + dsk: Graph = {} + for icohort, (blks, cohort) in enumerate(chunks_cohorts.items()): + cohort_index = pd.Index(cohort) + reindexer = ( + partial( + reindex_intermediates, + agg=agg, + unique_groups=cohort_index, + array_type=reindex.array_type, + ) + if do_simple_combine + else identity + ) + subset = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array) + dsk |= subset.layer # type: ignore[operator] + # now that we have reindexed, we can set reindex=True explicitlly + new_reindex = ReindexStrategy(blockwise=do_simple_combine, array_type=reindex.array_type) + _tree_reduce( + subset, + out_dsk=dsk, + name=out_name, + block_index=icohort, + axis=axis, + combine=partial(combine, agg=agg, reindex=new_reindex, keepdims=True), + aggregate=partial( + aggregate, expected_groups=cohort_index, reindex=new_reindex, keepdims=True + ), + ) + # This is done because pandas promotes to 64-bit types when an Index is created + # So we use the index to generate the return value for consistency with "map-reduce" + # This is important on windows + groups_.append(cohort_index.values) + + graph = HighLevelGraph.from_collections(out_name, dsk, dependencies=[intermediate]) + + out_chunks = list(array.chunks) + out_chunks[axis[-1]] = tuple(len(c) for c in chunks_cohorts.values()) + for ax in axis[:-1]: + out_chunks[ax] = (1,) + reduced = dask.array.Array(graph, out_name, out_chunks, meta=array._meta) + + groups = (np.concatenate(groups_),) + group_chunks = (tuple(len(cohort) for cohort in groups_),) + + elif method == "blockwise": + reduced = intermediate + if reindex.blockwise: + if TYPE_CHECKING: + assert expected_groups is not None + # TODO: we could have `expected_groups` be a dask array with appropriate chunks + # for now, we have a numpy array that is interpreted as listing all group labels + # that are present in every chunk + groups = (expected_groups,) + group_chunks = ((len(expected_groups),),) + else: + # TODO: use chunks_cohorts here; hard because chunks_cohorts does not include all-NaN blocks + # but the array after applying the blockwise op; does. We'd have to insert a subsetting op. 
+ # Here one input chunk → one output chunks + # find number of groups in each chunk, this is needed for output chunks + # along the reduced axis + # TODO: this logic is very specialized for the resampling case + slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) + groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) + groups = (np.concatenate(groups_in_block),) + ngroups_per_block = tuple(len(grp) for grp in groups_in_block) + group_chunks = (ngroups_per_block,) + else: + raise ValueError(f"Unknown method={method}.") + + # Adjust output for any new dimensions added, example for multiple quantiles + new_dims_shape = tuple(dim.size for dim in agg.new_dims if not dim.is_scalar) + new_inds = tuple(range(-len(new_dims_shape), 0)) + out_inds = new_inds + inds[: -len(axis)] + (inds[-1],) + output_chunks = new_dims_shape + reduced.chunks[: -len(axis)] + group_chunks + new_axes = dict(zip(new_inds, new_dims_shape)) + + if method == "blockwise" and len(axis) > 1: + # The final results are available but the blocks along axes + # need to be reshaped to axis=-1 + # I don't know that this is possible with blockwise + # All other code paths benefit from an unmaterialized Blockwise layer + reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks) + + # Can't use map_blocks because it forces concatenate=True along drop_axes, + result = dask.array.blockwise( + _extract_result, + out_inds, + reduced, + inds, + adjust_chunks=dict(zip(out_inds, output_chunks)), + key=agg.name, + name=f"{name}-{token}", + concatenate=False, + new_axes=new_axes, + meta=reindex.get_dask_meta(array, dtype=agg.dtype["final"], fill_value=agg.fill_value[agg.name]), + ) + + return (result, groups) + + +def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray: + from dask.array import map_blocks + from dask.array.reductions import cumreduction as scan + + from .aggregations import scan_binary_op + + if len(axes) > 1: + raise NotImplementedError("Scans are only supported along a single axis.") + (axis,) = axes + + array, by = _unify_chunks(array, by) + + # Import scan-specific functions from scan module + from .scan import _finalize_scan, _zip, chunk_scan, grouped_reduce + + # 1. zip together group indices & array + zipped = map_blocks( + _zip, + by, + array, + dtype=array.dtype, + meta=array._meta, + name="groupby-scan-preprocess", + ) + + scan_ = partial(chunk_scan, agg=agg) + # dask tokenizing error workaround + scan_.__name__ = scan_.func.__name__ # type: ignore[attr-defined] + + # 2. Run the scan + accumulated = scan( + func=scan_, + binop=partial(scan_binary_op, agg=agg), + ident=agg.identity, + x=zipped, + axis=axis, + # TODO: support method="sequential" here. + method="blelloch", + preop=partial(grouped_reduce, agg=agg), + dtype=agg.dtype, + ) + + # 3. Unzip and extract the final result array, discard groups + result = map_blocks(partial(_finalize_scan, dtype=agg.dtype), accumulated, dtype=agg.dtype) + + assert result.chunks == array.chunks + + return result + + +def _normalize_indexes(ndim: int, flatblocks: Sequence[int], blkshape: tuple[int, ...]) -> tuple: + """ + .blocks accessor can only accept one iterable at a time, + but can handle multiple slices. + To minimize tasks and layers, we normalize to produce slices + along as many axes as possible, and then repeatedly apply + any remaining iterables in a loop. 
+ + TODO: move this upstream + """ + unraveled = np.unravel_index(flatblocks, blkshape) + + normalized: list[int | slice | list[int]] = [] + for ax, idx in enumerate(unraveled): + i = _unique(idx).squeeze() + if i.ndim == 0: + normalized.append(i.item()) + else: + if len(i) == blkshape[ax] and np.array_equal(i, np.arange(blkshape[ax])): + normalized.append(slice(None)) + elif _issorted(i) and np.array_equal(i, np.arange(i[0], i[-1] + 1)): + start = None if i[0] == 0 else i[0] + stop = i[-1] + 1 + stop = None if stop == blkshape[ax] else stop + normalized.append(slice(start, stop)) + else: + normalized.append(list(i)) + full_normalized = (slice(None),) * (ndim - len(normalized)) + tuple(normalized) + + # has no iterables + noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized) + # has all iterables + alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")} + + mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values()))) # type: ignore[arg-type, var-annotated] + + full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter)) + + return full_tuple + + +def subset_to_blocks( + array: DaskArray, + flatblocks: Sequence[int], + blkshape: tuple[int, ...] | None = None, + reindexer=identity, + chunks_as_array: tuple[np.ndarray, ...] | None = None, +) -> ArrayLayer: + """ + Advanced indexing of .blocks such that we always get a regular array back. + + Parameters + ---------- + array : dask.array + flatblocks : flat indices of blocks to extract + blkshape : shape of blocks with which to unravel flatblocks + + Returns + ------- + dask.array + """ + from dask.base import tokenize + + from .lib import ArrayLayer + + if blkshape is None: + blkshape = array.blocks.shape + + if chunks_as_array is None: + chunks_as_array = tuple(np.array(c) for c in array.chunks) + + index = _normalize_indexes(array.ndim, flatblocks, blkshape) + + # These rest is copied from dask.array.core.py with slight modifications + index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index) + + name = "groupby-cohort-" + tokenize(array, index) + new_keys = array._key_array[index] + + squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index) + chunks = tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed)) + + keys = itertools.product(*(range(len(c)) for c in chunks)) + layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys} + return ArrayLayer(layer=layer, chunks=chunks, name=name) + + +def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]: + import dask.array + from dask.highlevelgraph import HighLevelGraph + + groups_token = f"group-{reduced.name}" + first_block = reduced.ndim * (0,) + layer: Graph = {(groups_token, 0): (operator.getitem, (reduced.name, *first_block), "groups")} + groups: tuple[DaskArray] = ( + dask.array.Array( + HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]), + groups_token, + chunks=((np.nan,),), + meta=np.array([], dtype=dtype), + ), + ) + + return groups + + +def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axes, group_chunks) -> DaskArray: + import dask.array + from dask.highlevelgraph import HighLevelGraph + + nblocks = tuple(reduced.numblocks[ax] for ax in axis) + output_chunks = reduced.chunks[: -len(axis)] + ((1,) * (len(axis) - 1),) + group_chunks + + # extract results from the dict + ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks) + layer2: dict[tuple, tuple] = 
{} + name = f"reshape-{reduced.name}" + + for ochunk in itertools.product(*ochunks): + inchunk = ochunk[: -len(axis)] + np.unravel_index(ochunk[-1], nblocks) + layer2[(name, *ochunk)] = (reduced.name, *inchunk) + + layer2: Graph + return dask.array.Array( + HighLevelGraph.from_collections(name, layer2, dependencies=[reduced]), + name, + chunks=output_chunks, + dtype=reduced.dtype, + ) + + +__all__ = [ + "_collapse_blocks_along_axes", + "_extract_unknown_groups", + "_grouped_combine", + "_normalize_indexes", + "_unify_chunks", + "dask_groupby_agg", + "dask_groupby_scan", + "subset_to_blocks", +] diff --git a/flox/scan.py b/flox/scan.py index d4da062ac..90999b448 100644 --- a/flox/scan.py +++ b/flox/scan.py @@ -3,7 +3,6 @@ from __future__ import annotations import copy -from functools import partial from typing import TYPE_CHECKING import numpy as np @@ -12,7 +11,6 @@ from .aggregations import AlignedArrays, Scan, ScanState from .types import ( DaskArray, - T_Axes, T_By, T_Bys, T_EngineOpt, @@ -210,6 +208,8 @@ def groupby_scan( final_state = chunk_scan(inp, axis=single_axis, agg=agg, dtype=agg.dtype) result = _finalize_scan(final_state, dtype=agg.dtype) else: + from .dask import dask_groupby_scan + result = dask_groupby_scan(inp.array, inp.group_idx, axes=axis_, agg=agg) # Made a design choice here to have `postprocess` handle both array and group_idx @@ -273,53 +273,4 @@ def _finalize_scan(block: ScanState, dtype) -> np.ndarray: return block.result.array.astype(dtype, copy=False) -def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray: - from dask.array import map_blocks - from dask.array.reductions import cumreduction as scan - - from flox.aggregations import scan_binary_op - - from .core import _unify_chunks - - if len(axes) > 1: - raise NotImplementedError("Scans are only supported along a single axis.") - (axis,) = axes - - array, by = _unify_chunks(array, by) - - # 1. zip together group indices & array - zipped = map_blocks( - _zip, - by, - array, - dtype=array.dtype, - meta=array._meta, - name="groupby-scan-preprocess", - ) - - scan_ = partial(chunk_scan, agg=agg) - # dask tokenizing error workaround - scan_.__name__ = scan_.func.__name__ # type: ignore[attr-defined] - - # 2. Run the scan - accumulated = scan( - func=scan_, - binop=partial(scan_binary_op, agg=agg), - ident=agg.identity, - x=zipped, - axis=axis, - # TODO: support method="sequential" here. - method="blelloch", - preop=partial(grouped_reduce, agg=agg), - dtype=agg.dtype, - ) - - # 3. 
Unzip and extract the final result array, discard groups - result = map_blocks(partial(_finalize_scan, dtype=agg.dtype), accumulated, dtype=agg.dtype) - - assert result.chunks == array.chunks - - return result - - __all__ = ["groupby_scan"] diff --git a/tests/test_core.py b/tests/test_core.py index 505f0cd28..64255f792 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -25,13 +25,12 @@ _convert_expected_groups_to_index, _get_optimal_chunks_for_groups, _is_sparse_supported_reduction, - _normalize_indexes, _validate_reindex, find_group_cohorts, groupby_reduce, rechunk_for_cohorts, - subset_to_blocks, ) +from flox.dask import _normalize_indexes, subset_to_blocks from flox.factorize import factorize_ from flox.reindex import reindex_ from flox.scan import groupby_scan From e657795ffe8156c111bd81ee044236391ddc5a0d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 16 Jul 2025 16:06:43 -0600 Subject: [PATCH 3/3] Move _postprocess_numbagg to aggregate_numbagg.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Relocates _postprocess_numbagg function from core.py to aggregate_numbagg.py - Updates import in core.py to use the new location - Groups numbagg-specific functionality together for better organization 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- flox/aggregate_numbagg.py | 17 +++++++++++++++++ flox/core.py | 20 ++------------------ 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index de8d1468b..cd42fbd0d 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -136,6 +136,23 @@ def nanlen(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None) any = partial(_numbagg_wrapper, func="nanany") all = partial(_numbagg_wrapper, func="nanall") + +def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups): + """Account for numbagg not providing a fill_value kwarg.""" + if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE: + return result + # The condition needs to be + # len(found_groups) < size; if so we mask with fill_value (?) + default_fv = DEFAULT_FILL_VALUE[func] + needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True) + groups = np.arange(size) + if needs_masking: + mask = np.isin(groups, seen_groups, assume_unique=True, invert=True) + if mask.any(): + result[..., groups[mask]] = fill_value + return result + + # sum = nansum # mean = nanmean # sum_of_squares = nansum_of_squares diff --git a/flox/core.py b/flox/core.py index ae7d3cdda..fa1156855 100644 --- a/flox/core.py +++ b/flox/core.py @@ -143,24 +143,6 @@ def get_dask_meta(self, other, *, fill_value, dtype) -> Any: return sparse.COO.from_numpy(np.ones(shape=(0,) * other.ndim, dtype=dtype), fill_value=fill_value) -def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups): - """Account for numbagg not providing a fill_value kwarg.""" - from .aggregate_numbagg import DEFAULT_FILL_VALUE - - if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE: - return result - # The condition needs to be - # len(found_groups) < size; if so we mask with fill_value (?) 
- default_fv = DEFAULT_FILL_VALUE[func] - needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True) - groups = np.arange(size) - if needs_masking: - mask = np.isin(groups, seen_groups, assume_unique=True, invert=True) - if mask.any(): - result[..., groups[mask]] = fill_value - return result - - def identity(x: T) -> T: return x @@ -901,6 +883,8 @@ def chunk_reduce( group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func ).astype(dt, copy=False) if engine == "numbagg": + from .aggregate_numbagg import _postprocess_numbagg + result = _postprocess_numbagg( result, func=reduction,
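
Taken together, the three patches leave the public helpers importable from their new homes (flox.scan, flox.factorize, flox.reindex, flox.utils, flox.dask, flox.cubed), which is how the updated tests import them. Below is a minimal smoke-test sketch of the eager code path; the labels, data values, and variable names are made up for illustration, and it assumes the patches are applied on top of a working flox install with numpy available:

    import numpy as np

    from flox.scan import groupby_scan        # scan entry point (patch 1)
    from flox.factorize import factorize_     # noqa: F401  (patch 1)
    from flox.reindex import reindex_         # noqa: F401  (patch 1)
    from flox.utils import ReindexArrayType   # noqa: F401  (patch 1)

    labels = np.array([0, 0, 1, 1, 0, 1])
    data = np.arange(6, dtype=float)

    # Eager (numpy) path; a chunked dask input would instead route through
    # flox.dask.dask_groupby_scan added in patch 2.
    result = groupby_scan(data, labels, func="nancumsum", axis=-1)
    # Expected grouped running sums: [0., 1., 2., 5., 5., 10.]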
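
ReindexArrayType itself is unchanged by the move to flox/utils.py: is_same_type reports whether an array already matches the requested reindex target. A small sketch of the three branches (the SPARSE_COO branch imports the optional sparse package lazily, so that final check assumes it is installed):

    import numpy as np

    from flox.utils import ReindexArrayType

    arr = np.zeros(3)
    assert ReindexArrayType.AUTO.is_same_type(arr)       # AUTO accepts any array type
    assert ReindexArrayType.NUMPY.is_same_type(arr)      # plain numpy ndarray
    assert not ReindexArrayType.SPARSE_COO.is_same_type(arr)  # needs `sparse` installed to check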
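
Patch 3 relocates _postprocess_numbagg without changing its behaviour: when the caller's fill_value differs from numbagg's default for that function, every group index in range(size) that never appeared in seen_groups is overwritten with fill_value. The masking rule, mimicked here with plain numpy (the arrays are invented for the example; numbagg itself is not needed to run it):

    import numpy as np

    size = 4                                       # number of expected groups
    seen_groups = np.array([0, 2])                 # groups actually present in this chunk
    fill_value = 0.0                               # caller-requested fill value
    result = np.array([3.0, np.nan, 5.0, np.nan])  # stand-in for numbagg's raw output

    groups = np.arange(size)
    mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
    if mask.any():
        result[..., groups[mask]] = fill_value
    # result is now [3., 0., 5., 0.]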