Commit 450c618

optimize cohorts yet again (#419)

1 parent ac319cc commit 450c618

3 files changed: 33 additions & 24 deletions

flox/core.py (10 additions & 9 deletions)

```diff
@@ -1465,7 +1465,7 @@ def _reduce_blockwise(
     return result


-def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
+def _normalize_indexes(ndim: int, flatblocks: Sequence[int], blkshape: tuple[int, ...]) -> tuple:
     """
     .blocks accessor can only accept one iterable at a time,
     but can handle multiple slices.
@@ -1483,20 +1483,23 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
         if i.ndim == 0:
             normalized.append(i.item())
         else:
-            if np.array_equal(i, np.arange(blkshape[ax])):
+            if len(i) == blkshape[ax] and np.array_equal(i, np.arange(blkshape[ax])):
                 normalized.append(slice(None))
-            elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
-                normalized.append(slice(i[0], i[-1] + 1))
+            elif _issorted(i) and np.array_equal(i, np.arange(i[0], i[-1] + 1)):
+                start = None if i[0] == 0 else i[0]
+                stop = i[-1] + 1
+                stop = None if stop == blkshape[ax] else stop
+                normalized.append(slice(start, stop))
             else:
                 normalized.append(list(i))
-    full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)
+    full_normalized = (slice(None),) * (ndim - len(normalized)) + tuple(normalized)

     # has no iterables
     noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
     # has all iterables
     alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")}

-    mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values())))
+    mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values())))  # type: ignore[arg-type, var-annotated]

     full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter))

@@ -1523,7 +1526,6 @@ def subset_to_blocks(
     -------
     dask.array
     """
-    from dask.array.slicing import normalize_index
     from dask.base import tokenize

     if blkshape is None:
@@ -1532,10 +1534,9 @@ def subset_to_blocks(
     if chunks_as_array is None:
         chunks_as_array = tuple(np.array(c) for c in array.chunks)

-    index = _normalize_indexes(array, flatblocks, blkshape)
+    index = _normalize_indexes(array.ndim, flatblocks, blkshape)

     # These rest is copied from dask.array.core.py with slight modifications
-    index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

     name = "groupby-cohort-" + tokenize(array, index)
```

flox/dask_array_ops.py (17 additions & 9 deletions)

```diff
@@ -1,6 +1,6 @@
 import builtins
 import math
-from functools import partial
+from functools import lru_cache, partial
 from itertools import product
 from numbers import Integral

@@ -84,14 +84,8 @@ def partial_reduce(
     axis: tuple[int, ...],
     block_index: int | None = None,
 ):
-    numblocks = tuple(len(c) for c in chunks)
-    ndim = len(numblocks)
-    parts = [list(partition_all(split_every.get(i, 1), range(n))) for (i, n) in enumerate(numblocks)]
-    keys = product(*map(range, map(len, parts)))
-    out_chunks = [
-        tuple(1 for p in partition_all(split_every[i], c)) if i in split_every else c
-        for (i, c) in enumerate(chunks)
-    ]
+    ndim = len(chunks)
+    keys, parts, out_chunks = get_parts(tuple(split_every.items()), chunks)
     for k, p in zip(keys, product(*parts)):
         free = {i: j[0] for (i, j) in enumerate(p) if len(j) == 1 and i not in split_every}
         dummy = dict(i for i in enumerate(p) if i[0] in split_every)
@@ -101,3 +95,17 @@ def partial_reduce(
             k = (*k[:-1], block_index)
         dsk[(name,) + k] = (func, g)
     return dsk, out_chunks
+
+
+@lru_cache
+def get_parts(split_every_items, chunks):
+    numblocks = tuple(len(c) for c in chunks)
+    split_every = dict(split_every_items)
+
+    parts = [list(partition_all(split_every.get(i, 1), range(n))) for (i, n) in enumerate(numblocks)]
+    keys = tuple(product(*map(range, map(len, parts))))
+    out_chunks = tuple(
+        tuple(1 for p in partition_all(split_every[i], c)) if i in split_every else c
+        for (i, c) in enumerate(chunks)
+    )
+    return keys, parts, out_chunks
```
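
The dask_array_ops.py change moves the keys/parts/out_chunks computation out of partial_reduce into an lru_cache-decorated helper, so repeated reductions over the same chunk structure reuse the result. Two details make the caching work: functools.lru_cache needs hashable arguments, so the split_every dict is passed as tuple(split_every.items()) and rebuilt inside get_parts; and keys and out_chunks are materialized as tuples, since handing back a cached generator (or a mutable list) would break on the second call. A standalone sketch of the pattern (the names here are illustrative, not flox API):

```python
from functools import lru_cache
from itertools import product


@lru_cache
def cached_keys(split_every_items: tuple[tuple[int, int], ...], numblocks: tuple[int, ...]):
    # Rebuild the dict from its hashable items; passing a dict directly
    # would raise "TypeError: unhashable type" inside lru_cache.
    split_every = dict(split_every_items)
    # Materialize as a tuple: a cached generator would already be
    # exhausted when served for the second time.
    return tuple(
        product(*(range(-(-n // split_every.get(i, 1))) for i, n in enumerate(numblocks)))
    )


split_every = {0: 4}
first = cached_keys(tuple(split_every.items()), (8, 3))   # computed
second = cached_keys(tuple(split_every.items()), (8, 3))  # served from the cache
assert first == second and cached_keys.cache_info().hits == 1
```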

tests/test_core.py (6 additions & 6 deletions)

```diff
@@ -5,7 +5,7 @@
 import warnings
 from collections.abc import Callable
 from functools import partial, reduce
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from unittest.mock import MagicMock, patch

 import numpy as np
@@ -1538,7 +1538,7 @@ def test_normalize_block_indexing_1d(flatblocks, expected):
     nblocks = 5
     array = dask.array.ones((nblocks,), chunks=(1,))
     expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
-    actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
+    actual = _normalize_indexes(array.ndim, flatblocks, array.blocks.shape)
     assert_equal_tuple(expected, actual)


@@ -1550,17 +1550,17 @@ def test_normalize_block_indexing_1d(flatblocks, expected):
         ((1, 2, 3), (0, slice(1, 4))),
         ((1, 3), (0, [1, 3])),
         ((0, 1, 3), (0, [0, 1, 3])),
-        (tuple(range(10)), (slice(0, 2), slice(None))),
-        ((0, 1, 3, 5, 6, 8), (slice(0, 2), [0, 1, 3])),
+        (tuple(range(10)), (slice(None, 2), slice(None))),
+        ((0, 1, 3, 5, 6, 8), (slice(None, 2), [0, 1, 3])),
         ((0, 3, 4, 5, 6, 8, 24), np.ix_([0, 1, 4], [0, 1, 3, 4])),
     ),
 )
-def test_normalize_block_indexing_2d(flatblocks, expected):
+def test_normalize_block_indexing_2d(flatblocks: tuple[int, ...], expected: tuple[Any, ...]) -> None:
     nblocks = 5
     ndim = 2
     array = dask.array.ones((nblocks,) * ndim, chunks=(1,) * ndim)
     expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
-    actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
+    actual = _normalize_indexes(array.ndim, flatblocks, array.blocks.shape)
     assert_equal_tuple(expected, actual)
```
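
To exercise just the affected tests against a development install of flox (with dask available), the usual pytest keyword selection works; a small Python driver equivalent to `python -m pytest tests/test_core.py -k normalize_block_indexing`:

```python
import pytest

# Select only the 1d and 2d block-indexing normalization tests.
pytest.main(["tests/test_core.py", "-k", "normalize_block_indexing"])
```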
