---
title: 'Improving GroupBy.map with Dask and Xarray'
date: '2024-11-21'
authors:
  - name: Patrick Hoefler
    github: phofl
summary: 'Recent dask improvements make GroupBy.map a lot better!'
---

Running large-scale GroupBy-Map patterns with Xarray on data backed by [Dask arrays](https://docs.dask.org/en/stable/array.html?utm_source=xarray-blog) is
an essential part of many typical geospatial workloads. Detrending is a very common
operation where this pattern is needed.

In this post, we will explore how and why this pattern caused so many pitfalls for Xarray users in
the past and how we improved performance and scalability with a few changes to how Dask
subselects data.

## What is GroupBy.map?

[`GroupBy.map`](https://docs.xarray.dev/en/stable/generated/xarray.core.groupby.DatasetGroupBy.map.html) lets you apply a User Defined Function (UDF)
that accepts and returns Xarray objects. The UDF will receive an Xarray object (either a Dataset or a DataArray) containing Dask arrays corresponding to a single group.
[`GroupBy.reduce`](https://docs.xarray.dev/en/stable/generated/xarray.core.groupby.DatasetGroupBy.reduce.html) is quite similar
in that it also applies a UDF, but in this case the UDF will receive the underlying Dask arrays, _not_ Xarray objects.
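
To make the distinction concrete, here is a minimal sketch on a small toy DataArray (the data and
variable names are illustrative assumptions, not the dataset used later in this post):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy data: two years of daily values
arr = xr.DataArray(
    np.arange(730.0),
    coords={"time": pd.date_range("2000-01-01", periods=730, freq="D")},
    dims="time",
)

# .map: the UDF receives an Xarray object for each group and returns one
anomaly = arr.groupby("time.dayofyear").map(lambda group: group - group.mean())

# .reduce: the UDF receives the underlying arrays instead of Xarray objects
daily_mean = arr.groupby("time.dayofyear").reduce(np.mean, dim="time")
```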

## The Application

Consider a typical workflow where you want to apply a detrending step. You may want to smooth out
the data by removing the trends over time. This is a common operation in climate science
and normally looks roughly like this:

```python
from xarray import DataArray

def detrending_step(arr: DataArray) -> DataArray:
    # important: the rolling operation is applied within a group
    return arr - arr.rolling(time=30, min_periods=1).mean()

data.groupby("time.dayofyear").map(detrending_step)
```

We are grouping by the day of the year and then calculating the rolling average over
30-year windows for each particular day.

Our example runs on a 1 TiB array with 64 years' worth of data and the following structure:
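
As a rough stand-in (the shapes, chunk sizes, and variable names below are illustrative
assumptions, not the exact 1 TiB dataset), an array with a comparable layout, daily data with one
year per chunk along time plus spatial chunking, could be built like this:

```python
import dask.array as da
import pandas as pd
import xarray as xr

# Illustrative stand-in: 64 years of daily data, one year per chunk along
# time, spatial chunks of 256x256 (all sizes are assumptions).
time = pd.date_range("1959-01-01", periods=64 * 365, freq="D")
data = xr.DataArray(
    da.random.random((len(time), 512, 512), chunks=(365, 256, 256)),
    dims=("time", "y", "x"),
    coords={"time": time},
    name="data",
)
```

This is the kind of `data` object the detrending snippet above operates on.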

The array isn't overly huge and the chunks are reasonably sized.

## The Problem

The general application seems straightforward: group by the day of the year and apply a UDF
to every group. However, there are a few pitfalls in this operation that can make the result
unusable. Our array is sorted by time, which means that we have to pick
entries from many different areas in the array to create a single group (corresponding to a single day of the year).
Picking the same day of every year is basically a slicing operation with a step size of 365.
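
In terms of the stand-in array from above, that access pattern looks roughly like this strided
selection (the day index is arbitrary):

```python
# Selecting the same calendar day from every year is a strided selection
# along time, touching one element out of each yearly chunk.
one_day_per_year = data.isel(time=slice(59, None, 365))
```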

Our example has a year's worth of data in a single chunk along the time axis. The general problem
exists for any workload where you have to access random entries of data. This
particular access pattern means that we have to pick one value per chunk, which is pretty
inefficient. The right side shows the individual groups that we are operating on.

One of the main issues with this pattern is that Dask will create a single output chunk per time
entry, i.e. each group will consist of as many chunks as we have years.

This results in a huge increase in the number of chunks:

This simple operation increases the number of chunks from 5000 to close to 2 million. Each
chunk only has a few hundred kilobytes of data. **This is pretty bad!**
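
One way to see this on your own data is to compare the per-axis block counts of the underlying
Dask arrays before and after (a sketch reusing the `data` array and `detrending_step` function
from above):

```python
result = data.groupby("time.dayofyear").map(detrending_step)

# numblocks is the per-axis chunk count of the underlying Dask array
print(data.data.numblocks)    # a modest number of input chunks
print(result.data.numblocks)  # vastly more output chunks with older Dask versions
```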

Dask computations generally scale with the number of chunks you have. Increasing the chunk count by such
a large factor is catastrophic. Each follow-up operation, even one as simple as `a - b`, will create 2 million
additional tasks.

The only workaround for users was to rechunk to something more sensible afterward, but that
still keeps the incredibly expensive indexing operation in the graph.
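
That workaround looked roughly like this (the target chunk size is an assumption):

```python
# Shrink the chunk count again; the expensive indexing stays in the graph.
rechunked = result.chunk({"time": 365})
```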

Note that this is the underlying problem that is [solved by flox](https://xarray.dev/blog/flox) for aggregations like `.mean()`,
which uses parallel-native algorithms to avoid the expense of indexing out each group.
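
For comparison, a plain aggregation over the same groups sidesteps the issue entirely when flox is
installed, since Xarray then picks it up automatically:

```python
# Requires flox to be installed; no group is indexed out individually.
daily_climatology = data.groupby("time.dayofyear").mean()
```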

## Improvements to the Data Selection algorithm

The way Dask selected the data was objectively pretty bad.
A rewrite of the underlying algorithm enabled us to achieve a much more robust result. The new
algorithm is a lot smarter about how to pick values from each individual chunk, but most importantly,
it will try to preserve the input chunk size as closely as possible.

For our initial example, it will put every group into a single chunk. This means that we will
end up with the number of chunks along the time axis being equal to the number of groups, i.e. 365.

The algorithm reduces the number of chunks from 2 million to roughly 30 thousand, which is a huge improvement
and a scale that Dask can easily handle. The graph is now much smaller, and the follow-up operations
will run a lot faster as well.
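
Assuming a Dask version recent enough to include the new selection algorithm, the effect can be
checked directly on the result's chunking (same sketch as before):

```python
result = data.groupby("time.dayofyear").map(detrending_step)

# The time axis now ends up with roughly one chunk per group
print(len(result.chunksizes["time"]))  # ~365 instead of one chunk per time entry
```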

This improvement will help every operation that relies on this selection pattern and make it scale a lot more
reliably than before. The algorithm is used very widely across Dask and Xarray and thus influences
many methods.

## What's next?

Xarray selects one group at a time for `groupby(...).map(...)`, i.e. it requires one selection operation
per group. This will hurt scalability if the dataset has a very large number of groups, because
the computation will create a very expensive graph. There is currently an effort to implement alternative,
shuffle-based APIs to circumvent that problem. A current PR is available [here](https://github.com/pydata/xarray/pull/9320).

The fragmentation of the output chunks caused by indexing hurts every workflow that selects data in a random
pattern (see the sketch after this list). This includes:

- `.sel` if you aren't using slices explicitly
- `.isel`
- `.sortby`
- `groupby(...).quantile()`
- and many more.
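
As one hypothetical example of such a pattern, consider a label-based selection with an explicit,
unordered set of timestamps (reusing the stand-in `data` array from above):

```python
import numpy as np

# Pick 1,000 random timestamps; this is a non-slice, random-pattern selection.
rng = np.random.default_rng(0)
some_dates = rng.choice(data.time.values, size=1_000, replace=False)
subset = data.sel(time=some_dates)
```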

We expect all of these workloads to be substantially improved now.

Additionally, [Dask improved a lot of things](https://docs.dask.org/en/stable/changelog.html#v2024-11-1) related to growing chunk sizes and the fragmentation
of chunks over the lifetime of a workload, with more improvements to come. This will help a lot of
users get better and more reliable performance.