Commit 010f01d

committed: more edits
1 parent 105c22d commit 010f01d

File tree

1 file changed: +38 / -23 lines

src/posts/flox-smart/index.md

Lines changed: 38 additions & 23 deletions
@@ -10,12 +10,14 @@ summary: 'flox adds heuristics for automatically choosing an appropriate strateg
## TL;DR

-`flox>=0.9` adds heuristics for automatically choosing an appropriate strategy with dask arrays! Here I describe how.
+`flox>=0.9` now automatically optimizes GroupBy reductions with dask arrays to reduce memory usage and increase parallelism! Here I describe how.

## What is flox?

[`flox` implements](https://flox.readthedocs.io/) grouped reductions for chunked array types like [cubed](https://cubed-dev.github.io/cubed/) and [dask](https://docs.dask.org/en/stable/array.html) using tree reductions.
Tree reductions ([example](https://people.csail.mit.edu/xchen/gpu-programming/Lecture11-reduction.pdf)) are a parallel-friendly way of computing common reduction operations like `sum`, `mean`, etc.
+Briefly, one computes the reduction for a subset of the array $N$ chunks at a time in parallel, then combines those results together again $N$ chunks at a time, until we have the final result.
+
Without flox, Xarray effectively shuffles — sorts the data to extract all values in a single group — and then runs the reduction group-by-group.
Depending on data layout or "chunking", this shuffle can be quite expensive.
![shuffle](https://flox.readthedocs.io/en/latest/_images/new-split-apply-combine-annotated.svg)
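To make the "N chunks at a time" idea concrete, here is a minimal, illustrative sketch of a tree reduction over in-memory chunks. This is plain NumPy, not flox's actual implementation; the `fanout` name and the fixed chunk list are assumptions for the example.

```python
import numpy as np

def tree_sum(chunks, fanout=4):
    # Round 0: one partial result per chunk, computed independently (in parallel).
    partials = [np.sum(c) for c in chunks]
    # Combine `fanout` partial results at a time until a single value remains.
    while len(partials) > 1:
        partials = [
            np.sum(partials[i : i + fanout])
            for i in range(0, len(partials), fanout)
        ]
    return partials[0]

chunks = [np.ones(10) for _ in range(16)]  # a 160-element array split into 16 chunks
assert tree_sum(chunks) == 160.0
```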
@@ -25,13 +27,6 @@ Notice how much cleaner the graph is in this image:
![map-reduce](https://flox.readthedocs.io/en/latest/_images/new-map-reduce-reindex-True-annotated.svg)
See our [previous blog post](https://xarray.dev/blog/flox) for more.

-Two key realizations influenced the development of flox:
-
-1. Array workloads frequently group by a relatively small in-memory array. Quite frequently those arrays have patterns to their values e.g. `"time.month"` is exactly periodic, `"time.dayofyear"` is approximately periodic (depending on calendar), `"time.year"` is commonly a monotonic increasing array.
-2. Chunk sizes (or "partition sizes") for arrays can be quite small along the core-dimension of an operation. This is an important difference between arrays and dataframes!
-
-These two properties are particularly relevant for "climatology" calculations (e.g. `groupby("time.month").mean()`) — a common Xarray workload in the Earth Sciences.
-
## Tree reductions can be catastrophically bad

Consider `ds.groupby("time.year").mean()`, or the equivalent `ds.resample(time="Y").mean()`, for a 100-year-long dataset of monthly averages with a chunk size of **1** (or **4**) along the time dimension.
@@ -41,11 +36,17 @@ The small chunk size along time is offset by much larger chunk sizes along the o
A naive tree reduction would accumulate all averaged values into a single output chunk of size 100 — one value per year for 100 years.
Depending on the chunking of the input dataset, this may overload the final worker's memory and fail catastrophically.
More importantly, there is a lot of wasteful communication — computing on the last year of data is completely independent of computing on the first year of the data, and there is no reason the results for the two years need to reside in the same output chunk.
-This issue does not arise for regular reductions where the final result depends on the values in all chunks, and all data along the reduced axes are reduced down to one final value.
+This issue does _not_ arise for regular reductions where the final result depends on the values in all chunks, and all data along the reduced axes are reduced down to one final value.
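To make the failure mode concrete, here is a hypothetical setup of that problem case, assuming `dask` is installed; the variable names and array sizes are made up for illustration.

```python
import dask.array as da
import pandas as pd
import xarray as xr

# 100 years of monthly means, chunked with a single time step per chunk.
time = pd.date_range("1920-01-01", periods=100 * 12, freq="MS")
sst = da.random.random((len(time), 90, 180), chunks=(1, 90, 180))
ds = xr.Dataset({"sst": (("time", "lat", "lon"), sst)}, coords={"time": time})

# With a naive tree reduction, all 100 yearly means funnel into a single
# output chunk on one worker, even though each year is independent.
result = ds.groupby("time.year").mean()
```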

## Avoiding catastrophe

Thus `flox` quickly grew two new modes of computing the groupby reduction.
+Two key realizations influenced that development:
+
+1. Array workloads frequently group by a relatively small in-memory array. Quite frequently those arrays have patterns to their values, e.g. `"time.month"` is exactly periodic, `"time.dayofyear"` is approximately periodic (depending on calendar), and `"time.year"` is commonly a monotonically increasing array.
+2. Chunk sizes (or "partition sizes") for arrays can be quite small along the core-dimension of an operation. This is an important difference between arrays and dataframes!
+
+These two properties are particularly relevant for "climatology" calculations (e.g. `groupby("time.month").mean()`) — a common Xarray workload in the Earth Sciences.

First, `method="blockwise"`, which applies the grouped reduction in a blockwise fashion.
This is great for `resample(time="Y").mean()` where we group by `"time.year"`, which is a monotonically increasing array.
@@ -61,7 +62,7 @@ Here is a schematic illustration where each month is represented by a different
This means that we can run the tree reduction for each cohort (three cohorts in total: `JFMA | MJJA | SOND`) independently and expose more parallelism.
Doing so can significantly reduce compute times and, in particular, the memory required for the computation.

-Importantly if there isn't much separation of groups into cohorts; example, the groups are randomly distributed, then it's hard to do better than the standard `method="map-reduce"`.
+If there isn't much separation of groups into cohorts, like when groups are randomly distributed across chunks, then it's hard to do better than the standard `method="map-reduce"`.
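For readers who want to pick a strategy by hand, here is a small sketch using flox's Xarray interface. It assumes flox is installed and that `ds` is a monthly dataset like the hypothetical one above; with `flox>=0.9` you can simply leave `method` unset and let flox decide.

```python
import flox.xarray

# Explicitly request the "cohorts" strategy for a monthly climatology.
clim = flox.xarray.xarray_reduce(
    ds, ds.time.dt.month, func="mean", method="cohorts"
)

# Other choices are method="map-reduce" (the default tree reduction) and
# method="blockwise" (each group must live entirely within one chunk).
```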

## Choosing a strategy is hard, and harder to teach.

@@ -101,7 +102,7 @@ Importantly, we do _not_ want to be dependent on detecting exact patterns, and p

After a fun exploration involving ideas such as [locality-sensitive hashing](http://ekzhu.com/datasketch/lshensemble.html) and [all-pair set similarity search](https://www.cse.unsw.edu.au/~lxue/WWW08.pdf), I settled on the following algorithm.

-I use set _containment_, or a "normalized intersection", to determine the similarity the sets of chunks occupied by two different groups (`Q` and `X`).
+I use set _containment_, or a "normalized intersection", to determine the similarity between the sets of chunks occupied by two different groups (`Q` and `X`).

```
C = |Q ∩ X| / |Q| ≤ 1; (∩ is set intersection)
@@ -114,23 +115,34 @@ The steps are as follows:
1. First determine which labels are present in each chunk. The distribution of labels across chunks
   is represented internally as a 2D boolean sparse array `S[chunks, labels]`. `S[i, j] = 1` when
   label `j` is present in chunk `i`.
-1. Now we can quickly determine a number of special cases:
-   1. Use `"blockwise"` when every group is contained to one block each.
-   1. Use `"cohorts"` when every chunk only has a single group, but that group might extend across multiple chunks
-   1. [and more](https://github.com/xarray-contrib/flox/blob/e6159a657c55fa4aeb31bcbcecb341a4849da9fe/flox/core.py#L408-L426)
1. Now invert `S` to compute an initial set of cohorts whose groups are in the same exact chunks (this is another groupby!).
   Later we will want to merge together the detected cohorts when they occupy _approximately_ the same chunks, using the containment metric.
+1. Now we can quickly determine a number of special cases and exit early:
+   1. Use `"blockwise"` when every group is contained in a single block.
+   1. Use `"cohorts"` when
+      1. every chunk only has a single group, but that group might extend across multiple chunks; and
+      1. existing cohorts don't overlap at all.
+   1. [and more](https://github.com/xarray-contrib/flox/blob/e6159a657c55fa4aeb31bcbcecb341a4849da9fe/flox/core.py#L408-L426)
+
+   If we reach here, then we want to merge together any detected cohorts that substantially overlap with each other.
+
1. For that we first quickly compute containment for all groups `i` against all other groups `j` as `C = S.T @ S / number_chunks_per_group`.
1. To choose between `"map-reduce"` and `"cohorts"`, we need a summary measure of the degree to which the labels overlap with
-   each other. We can use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
-   We use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`. When sparsity is relatively high, we use `"map-reduce"`, otherwise we use `"cohorts"`.
-1. If the sparsity is high enough, we merge together similar cohorts using a for-loop.
-1. Finally we execute one tree-reduction per cohort and concatenate the results.
+   each other. We use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
+   When sparsity is relatively high, we use `"map-reduce"`, otherwise we use `"cohorts"`.
+1. If the sparsity is low enough, we merge together similar cohorts using a for-loop.
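Here is a small, self-contained sketch of the containment and sparsity computation described above. It is illustrative only (flox's real implementation uses sparse arrays and handles many more cases), and it reuses the `JFMA | MJJA | SOND` layout: monthly labels with four time steps per chunk.

```python
import numpy as np

# 10 years of monthly labels, 4 time steps per chunk: cohorts JFMA | MJJA | SOND.
labels = np.tile(np.arange(12), 10)        # group label (month) for each time step
chunk_ids = np.arange(labels.size) // 4    # which chunk each time step falls in

nchunks, nlabels = chunk_ids.max() + 1, labels.max() + 1
S = np.zeros((nchunks, nlabels), dtype=bool)   # S[i, j]: label j occurs in chunk i
S[chunk_ids, labels] = True

# Containment C[i, j] = |Q_i ∩ Q_j| / |Q_i|, where Q_i is the set of chunks
# occupied by group i.
chunks_per_group = S.sum(axis=0)
C = (S.astype(int).T @ S.astype(int)) / chunks_per_group[:, None]

# "Sparsity" as used above: fraction of non-zero entries in C.
sparsity = np.count_nonzero(C) / C.size
print(sparsity)  # 0.333...: low, so flox would pick "cohorts" and merge similar groups
```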

For more detail [see the docs](https://flox.readthedocs.io/en/latest/implementation.html#heuristics) or [the code](https://github.com/xarray-contrib/flox/blob/e6159a657c55fa4aeb31bcbcecb341a4849da9fe/flox/core.py#L336).
Suggestions and improvements are very welcome!

-Here is `C` for a range of chunk sizes from 1 to 12, for computing `groupby("time.month")` of a monthly mean dataset, [the title on each image is (chunk size, sparsity)].
+Here is the containment matrix `C[i, j]` for a range of chunk sizes from 1 to 12, for an input array with 12 monthly mean time steps,
+when computing `groupby("time.month")` of a monthly mean dataset.
+These are colored so that light yellow is $C=0$, and dark purple is $C=1$.
+The title on each image is (chunk size, sparsity).
+`C[i,j] = 1` when the chunks occupied by group `i` perfectly overlap with those occupied by group `j` (so the diagonal elements
+are always 1).
+When the chunksize _is_ a divisor of the period 12, $C$ is a [block diagonal](https://en.wikipedia.org/wiki/Block_matrix) matrix.
+When the chunksize _is not_ a divisor of the period 12, $C$ is much less sparse in comparison.
![flox sparsity image](https://flox.readthedocs.io/en/latest/_images/containment.png)

Given the above `C`, flox will choose:
@@ -146,8 +158,10 @@ But we have not tried with bigger problems (example: GroupBy(100,000 watersheds)

## What's next?

-flox' ability to do such inferences relies entirely on the input chunking, a big knob.
-A recent Xarray feature makes such rechunking a lot easier for time grouping:
+flox' ability to cleanly infer an optimal strategy relies entirely on the input chunking making such optimization possible.
+This is a big knob.
+A brand new [Xarray feature](https://docs.xarray.dev/en/stable/user-guide/groupby.html#grouper-objects) does make such rechunking
+a lot easier for time grouping in particular:

```python
from xarray.groupers import TimeResampler
@@ -156,13 +170,14 @@ rechunked = ds.chunk(time=TimeResampler("YE"))
```

will rechunk so that a year of data is in a single chunk.
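As a quick check, with the hypothetical monthly dataset from earlier, each chunk along `time` now holds one calendar year:

```python
print(rechunked.chunksizes["time"])  # (12, 12, 12, ..., 12): one year per chunk
```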
-
Even so, it would be nice to automatically rechunk to minimize the number of cohorts detected, or to a perfectly blockwise application when that's cheap.
+
A challenge here is that we have lost _context_ when moving from Xarray to flox.
The string `"time.month"` tells Xarray that I am grouping a perfectly periodic array with period 12; similarly
the _string_ `"time.dayofyear"` tells Xarray that I am grouping by a (quasi-)periodic array with period 365, and that group `366` may occur occasionally (depending on calendar).
But Xarray passes flox an array of integer group labels `[1, 2, 3, 4, 5, ..., 1, 2, 3, 4, 5]`.
It's hard to infer the context from that!
+Though one approach might frame the problem as: what rechunking would transform `C` into a block diagonal matrix?
_[Get in touch](https://github.com/xarray-contrib/flox/issues) if you have ideas for how to do this inference._

One way to preserve context may be to have Xarray's new Grouper objects report ["preferred chunks"](https://github.com/pydata/xarray/blob/main/design_notes/grouper_objects.md#the-preferred_chunks-method-) for a particular grouping.
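To make that idea concrete, here is a purely hypothetical sketch; `preferred_chunks` is not an existing Xarray or flox API, and the name is borrowed only from the linked design note.

```python
from xarray.groupers import TimeResampler

resampler = TimeResampler("YE")
# Hypothetical: a Grouper could report chunk sizes that keep every group in a
# single block, so flox could pick method="blockwise" (or rechunk to make it so).
# preferred = resampler.preferred_chunks(ds.time)   # e.g. {"time": (12, 12, ...)}
# ds = ds.chunk(time=preferred)
```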
