Commit 3c587fb

phofl, dcherian, scharlottej13, and andersy005 authored
Add blogpost for Detrending operations with Dask and Xarray (#728)
* Add post
* Update
* Suggestions
* edits
* Add front matter
* Fixup
* Remove ticks
* nits
* add social share card

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: scharlottej13 <[email protected]>
Co-authored-by: Anderson Banihirwe <[email protected]>
1 parent 12f2a5e commit 3c587fb

File tree

6 files changed: +123 -0 lines changed


public/cards/dask-detrending.png and four other binary image files added (134 KB, 219 KB, 104 KB, 102 KB, 103 KB)

src/posts/dask-detrending/index.md

Lines changed: 123 additions & 0 deletions

---
title: 'Improving GroupBy.map with Dask and Xarray'
date: '2024-11-21'
authors:
  - name: Patrick Hoefler
    github: phofl
summary: 'Recent Dask improvements make GroupBy.map a lot better!'
---

Running large-scale GroupBy.map patterns with Xarray backed by [Dask arrays](https://docs.dask.org/en/stable/array.html?utm_source=xarray-blog) is an essential part of many typical geospatial workloads. Detrending is a very common operation that requires this pattern.

In this post, we will explore how and why this pattern caused so many pitfalls for Xarray users in the past, and how we improved performance and scalability with a few changes to how Dask subselects data.

## What is GroupBy.map?

[`GroupBy.map`](https://docs.xarray.dev/en/stable/generated/xarray.core.groupby.DatasetGroupBy.map.html) lets you apply a User Defined Function (UDF) that accepts and returns Xarray objects. The UDF receives an Xarray object (either a Dataset or a DataArray) containing Dask arrays corresponding to one single group.
[`GroupBy.reduce`](https://docs.xarray.dev/en/stable/generated/xarray.core.groupby.DatasetGroupBy.reduce.html) is quite similar in that it also applies a UDF, but in this case the UDF receives the underlying Dask arrays, _not_ Xarray objects.
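To make the difference concrete, here is a minimal sketch on a small synthetic two-year daily series; the toy array and the `center` UDF are hypothetical and are not the dataset used in this post:

```python
import numpy as np
import pandas as pd
import xarray as xr

arr = xr.DataArray(
    np.random.rand(730),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=730, freq="D")},
).chunk({"time": 365})

# GroupBy.map: the UDF receives and returns an Xarray object, one call per group
def center(group: xr.DataArray) -> xr.DataArray:
    return group - group.mean()

centered = arr.groupby("time.dayofyear").map(center)

# GroupBy.reduce: the UDF receives the underlying Dask/NumPy arrays instead
reduced = arr.groupby("time.dayofyear").reduce(np.mean)
```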

## The Application

Consider a typical workflow where you want to apply a detrending step. You may want to smooth out the data by removing the trends over time. This is a common operation in climate science and normally looks roughly like this:

```python
from xarray import DataArray

def detrending_step(arr: DataArray) -> DataArray:
    # important: the rolling operation is applied within a group
    return arr - arr.rolling(time=30, min_periods=1).mean()

data.groupby("time.dayofyear").map(detrending_step)
```

We are grouping by the day of the year and then calculating the rolling average over 30-year windows for a particular day.

Our example runs on a 1 TiB array with 64 years' worth of data and the following structure:

![](/posts/dask-detrending/input-array.png)

The array isn't overly huge and the chunks are reasonably sized.
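If you want to follow along without the original data, a lazily constructed stand-in with a similar layout might look like the sketch below. The spatial shape, chunk sizes, variable name, and start date are guesses chosen only to reach roughly 1 TiB with one year per chunk along time; the real dataset may differ:

```python
import dask.array as da
import pandas as pd
import xarray as xr

time = pd.date_range("1958-01-01", periods=64 * 365, freq="D")  # 64 years, ignoring leap days
data = xr.DataArray(
    da.random.random((time.size, 2400, 2400), chunks=(365, 300, 300)),
    dims=("time", "y", "x"),
    coords={"time": time},
    name="temperature",
)
data.nbytes / 2**40  # ~1 TiB of float64 data, split into ~4000 chunks of ~260 MB
```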

## The Problem

The general application seems straightforward. Group by the day of the year and apply a UDF to every group. There are a few pitfalls in this application that can make the result of this operation unusable. Our array is sorted by time, which means that we have to pick entries from many different areas in the array to create a single group (corresponding to a single day of the year). Picking the same day of every year is basically a slicing operation with a step size of 365.
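In indexing terms, assembling a single group, say the first day of the year, is roughly equivalent to a strided selection like this simplified sketch (which ignores leap years):

```python
# One value out of every 365, i.e. one value from each yearly chunk
day_one = data.isel(time=slice(0, None, 365))
```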

![](/posts/dask-detrending/indexing-data-selection.png 'Data Selection Pattern')

Our example has a year's worth of data in a single chunk along the time axis. The general problem exists for any workload where you have to access random entries of data. This particular access pattern means that we have to pick one value per chunk, which is pretty inefficient. The right side shows the individual groups that we are operating on.

One of the main issues with this pattern is that Dask will create a single output chunk per time entry, i.e. each group will consist of as many chunks as we have years.

This results in a huge increase in the number of chunks:

![](/posts/dask-detrending/output-array-old.png)

This simple operation increases the number of chunks from 5000 to close to 2 million. Each chunk only has a few hundred kilobytes of data. **This is pretty bad!**

Dask computations generally scale with the number of chunks you have. Increasing the number of chunks by such a large factor is catastrophic. Every follow-up operation, even one as simple as `a - b`, will create 2 million additional tasks.
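A rough way to see this blow-up on your own data is to inspect the chunk and task counts of the result; this sketch reuses `data` and `detrending_step` from above, and the exact numbers will depend on your dataset:

```python
import math

detrended = data.groupby("time.dayofyear").map(detrending_step)

math.prod(detrended.data.numblocks)  # number of chunks in the resulting Dask array
len(detrended.data.dask)             # rough number of tasks in the graph
```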

The only workaround for users was to rechunk to something more sensible afterward, but that still keeps the incredibly expensive indexing operation in the graph.
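That workaround looked roughly like this sketch, with an illustrative target chunk size:

```python
detrended = (
    data.groupby("time.dayofyear")
    .map(detrending_step)
    .chunk({"time": 365})  # merge the tiny per-entry chunks back together
)
```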

Note that this is the underlying problem that is [solved by flox](https://xarray.dev/blog/flox) for aggregations like `.mean()`, using parallel-native algorithms to avoid the expense of indexing out each group.
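For plain aggregations this means that a grouped reduction like the one below is handled by flox whenever it is installed; the option is shown here only to make the dependency explicit, since Xarray enables it by default:

```python
import xarray as xr

with xr.set_options(use_flox=True):  # the default when flox is installed
    climatology = data.groupby("time.dayofyear").mean()
```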

## Improvements to the Data Selection algorithm

The way Dask selected the data was objectively pretty bad. A rewrite of the underlying algorithm enabled us to achieve a much more robust result. The new algorithm is a lot smarter about how to pick values from each individual chunk, but most importantly, it tries to preserve the input chunksize as closely as possible.
For our initial example, it will put every group into a single chunk. This means that we will end up with the number of chunks along the time axis being equal to the number of groups, i.e. 365.

![](/posts/dask-detrending/output-array-new.png)

The algorithm reduces the number of chunks from 2 million to roughly 30 thousand, which is a huge improvement and a scale that Dask can easily handle. The graph is now much smaller, and the follow-up operations will run a lot faster as well.
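You can verify the new behaviour by re-running the example and looking at the chunk sizes of the result (a sketch; the exact tuple depends on your data):

```python
detrended = data.groupby("time.dayofyear").map(detrending_step)
detrended.chunksizes["time"]  # one chunk per day-of-year group instead of one per element
```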

This improvement will help every operation that we listed above and make them scale a lot more reliably than before. The algorithm is used very widely across Dask and Xarray and thus influences many methods.

## What's next?

Xarray selects one group at a time for `groupby(...).map(...)`, i.e. this requires one operation per group. This will hurt scalability if the dataset has a very large number of groups, because the computation will create a very expensive graph. There is currently an effort to implement alternative APIs that are shuffle-based to circumvent that problem. A current PR is available [here](https://github.com/pydata/xarray/pull/9320).

The fragmentation of the output chunks by indexing is something that will hurt every workflow that selects data in a random pattern. This also includes:

- `.sel` if you aren't using slices explicitly
- `.isel`
- `.sortby`
- `groupby(...).quantile()`
- and many more.

We expect all of these workloads to be substantially improved now.
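As an illustration of the first two items, selecting with a scattered set of indexers falls into this random-access category, while slice-based selection does not (a hypothetical sketch):

```python
import numpy as np

# Random-order positional selection: the fragmentation-prone pattern
shuffled = data.isel(time=np.random.permutation(data.sizes["time"]))

# Contiguous slice-based selection: unaffected by this problem
decade = data.sel(time=slice("1990", "1999"))
```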

Additionally, [Dask improved a lot of things](https://docs.dask.org/en/stable/changelog.html#v2024-11-1) related to increasing chunksizes and the fragmentation of chunks over the lifetime of a workload, with more improvements to come. This will help a lot of users get better and more reliable performance.
