
Commit 61bea34

Implement most common regridder
1 parent 3b3f5fe commit 61bea34

File tree

src/xarray_regrid/most_common.py
src/xarray_regrid/utils.py

2 files changed: +265 -0 lines changed

src/xarray_regrid/most_common.py

Lines changed: 253 additions & 0 deletions

@@ -0,0 +1,253 @@
"""Implementation of the "most common value" regridding method."""

from itertools import product
from typing import overload

import flox.xarray
import numpy as np
import numpy_groupies as npg
import pandas as pd
import xarray as xr
from flox import Aggregation

from xarray_regrid import utils


@overload
def most_common_wrapper(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    time_dim: str = "",
    max_mem: int | None = None,
) -> xr.DataArray:
    ...


@overload
def most_common_wrapper(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    time_dim: str = "",
    max_mem: int | None = None,
) -> xr.Dataset:
    ...


def most_common_wrapper(
    data: xr.DataArray | xr.Dataset,
    target_ds: xr.Dataset,
    time_dim: str = "",
    max_mem: int | None = None,
) -> xr.DataArray | xr.Dataset:
    """Wrapper for the most common regridder, allowing for analyzing larger datasets.

    Args:
        data: Input dataset.
        target_ds: Dataset whose coordinates the input dataset should be regridded to.
        time_dim: Name of the time dimension, which is skipped by the regridder
            (the regridders do not regrid over time). Defaults to an empty string
            (no time dimension).
        max_mem: (Approximate) maximum memory in bytes that the regridding routines
            can use. Note that this is not the total memory consumption and does not
            include the size of the final dataset. If this kwarg is used, the
            regridding is split up into more manageable chunks, which are combined
            into the final dataset.

    Returns:
        xarray Dataset or DataArray (matching the input) with regridded categorical
        data.
    """
    da_name = None
    if isinstance(data, xr.DataArray):
        da_name = "da" if data.name is None else data.name
        data = data.to_dataset(name=da_name)

    coords = utils.common_coords(data, target_ds)
    coord_size = [data[coord].size for coord in coords]
    # Estimate the memory use of the source grid, assuming 8 bytes (int64) per point.
    mem_usage = np.prod(coord_size) * np.zeros((1,), dtype=np.int64).itemsize

    if max_mem is not None and mem_usage > max_mem:
        result = split_combine_most_common(
            data=data, target_ds=target_ds, time_dim=time_dim, max_mem=max_mem
        )
    else:
        result = most_common(data=data, target_ds=target_ds, time_dim=time_dim)

    if da_name is not None:
        return result[da_name]
    else:
        return result
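
For reference, a minimal usage sketch of the wrapper (the synthetic grid, coordinate names, and memory cap below are illustrative assumptions, not part of this commit):

import numpy as np
import xarray as xr

from xarray_regrid.most_common import most_common_wrapper

# Hypothetical 0.1-degree categorical map with labels 0..9.
lc = xr.DataArray(
    np.random.randint(0, 10, size=(100, 100)),
    dims=("latitude", "longitude"),
    coords={"latitude": np.arange(0, 10, 0.1), "longitude": np.arange(0, 10, 0.1)},
    name="land_cover",
)
# Hypothetical 0.5-degree target grid.
target = xr.Dataset(
    coords={"latitude": np.arange(0, 10, 0.5), "longitude": np.arange(0, 10, 0.5)}
)

# Picks the most common label per target cell, splitting the work into blocks
# if the source grid would exceed the ~1 GB cap.
regridded = most_common_wrapper(lc, target, max_mem=int(1e9))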


def split_combine_most_common(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    time_dim: str,
    max_mem: int = 1_000_000_000,
) -> xr.Dataset:
    """Use a split-combine strategy to reduce the memory use of the most_common regrid.

    Args:
        data: Input dataset.
        target_ds: Dataset whose coordinates the input dataset should be regridded to.
        time_dim: Name of the time dimension, which is skipped by the regridder.
        max_mem: (Approximate) maximum memory in bytes that the regridding routines
            can use. Note that this is not the total memory consumption and does not
            include the size of the final dataset. Defaults to 1e9 (1 GB).

    Returns:
        xarray.Dataset with regridded categorical data.
    """
    coords = utils.common_coords(data, target_ds, remove_coord=time_dim)
    max_datapoints = max_mem // 8  # ~8 bytes per item.
    max_source_coord_size = max_datapoints ** (1 / len(coords))
    # Ratio of the source resolution to the target resolution along each coordinate.
    size_ratios = {
        coord: (
            np.median(np.diff(data[coord].to_numpy(), 1))
            / np.median(np.diff(target_ds[coord].to_numpy(), 1))
        )
        for coord in coords
    }
    max_coord_size = {
        coord: int(size_ratios[coord] * max_source_coord_size) for coord in coords
    }

    # Starting indices of the target-grid blocks along each coordinate.
    blocks = {
        coord: np.arange(0, target_ds[coord].size, max_coord_size[coord])
        for coord in coords
    }

    # Regrid each block separately, then merge the results.
    subsets = []
    for vals in product(*blocks.values()):
        isel = {}
        for coord, val in zip(blocks.keys(), vals, strict=True):
            isel[coord] = slice(val, val + max_coord_size[coord])
        subsets.append(most_common(data, target_ds.isel(isel), time_dim=time_dim))

    return xr.merge(subsets)
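
To make the block-size arithmetic concrete, a rough worked example (the 0.1-degree source and 0.5-degree target resolutions are illustrative):

max_mem = 1_000_000_000
max_datapoints = max_mem // 8                              # 125_000_000 source points
max_source_coord_size = max_datapoints ** (1 / 2)          # 2 coords -> ~11_180 per axis
size_ratio = 0.1 / 0.5                                     # source res / target res
max_coord_size = int(size_ratio * max_source_coord_size)   # -> 2236 target cells per axis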


def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Dataset:
    """Regrid data with a "most common label" approach.

    The implementation includes two steps:
    - "groupby" the coordinates
    - select the most common label per group

    We use flox to perform the "groupby" over multiple dimensions. Here is an example:
    https://flox.readthedocs.io/en/latest/intro.html#histogramming-binning-by-multiple-variables

    To embed our customized function for most common label selection, we need to
    create our own `flox.Aggregation`, for instance:
    https://flox.readthedocs.io/en/latest/aggregations.html

    A `flox.Aggregation` works with the `numpy_groupies.aggregate_numpy.aggregate`
    API. Therefore this function also depends on `numpy_groupies`. For more
    information, check the following example:
    https://flox.readthedocs.io/en/latest/user-stories/custom-aggregations.html

    Note that this function cannot handle a dataset that does not fit into memory.
    In that case, please first coarsen the land cover data with the `coarse`
    function and then apply the regridder.

    Args:
        data: Input dataset.
        target_ds: Dataset whose coordinates the input dataset should be regridded to.
        time_dim: Name of the time dimension, which is skipped by the regridder.

    Returns:
        xarray.Dataset with regridded land cover categorical data.
    """
    coords = utils.common_coords(data, target_ds, remove_coord=time_dim)
    # One interval per target cell, centered on the target coordinate values.
    bounds = tuple(
        _construct_intervals(target_ds[coord].to_numpy()) for coord in coords
    )

    # Slice the input data to the bounds of the target dataset.
    data = data.sortby(list(coords))
    for coord in coords:
        coord_res = np.median(np.diff(target_ds[coord].to_numpy(), 1))
        data = data.sel(
            {
                coord: slice(
                    target_ds[coord].min().to_numpy() - coord_res,
                    target_ds[coord].max().to_numpy() + coord_res,
                )
            }
        )

    most_common = Aggregation(
        name="most_common",
        numpy=_custom_grouped_reduction,
        chunk=None,
        combine=None,
    )

    ds_regrid: xr.Dataset = flox.xarray.xarray_reduce(
        data.compute(),
        *coords,
        func=most_common,
        expected_groups=bounds,
    )

    # Rename the "*_bins" interval coordinates back to plain coordinates.
    ds_regrid = ds_regrid.rename({f"{coord}_bins": coord for coord in coords})
    for coord in coords:
        ds_regrid[coord] = target_ds[coord]

    return ds_regrid


def _construct_intervals(coord: np.ndarray) -> pd.IntervalIndex:
    """Create a pandas.IntervalIndex from the given coordinates."""
    step_size = np.median(np.diff(coord, n=1))
    # Breaks sit halfway between the coordinate points (the cell centers).
    breaks = np.append(coord, coord[-1] + step_size) - step_size / 2

    # Note: closed="both" triggers a NotImplementedError.
    return pd.IntervalIndex.from_breaks(breaks, closed="left")
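
For example, three cell centers spaced one unit apart produce left-closed intervals reaching half a step either side of each center (a quick illustration, not part of the commit):

intervals = _construct_intervals(np.array([0.0, 1.0, 2.0]))
# -> [-0.5, 0.5), [0.5, 1.5), [1.5, 2.5)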


def _most_common_label(neighbors: np.ndarray) -> np.ndarray:
    """Find the most common label in a neighborhood.

    Note that if more than one label ties for the highest frequency, the first
    one in the sorted output of np.unique is picked.
    """
    unique_labels, counts = np.unique(neighbors, return_counts=True)
    return unique_labels[np.argmax(counts)]
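
The tie-breaking behavior in a quick example (illustrative values):

labels = np.array([3, 1, 3, 1, 2])
values, counts = np.unique(labels, return_counts=True)  # values=[1, 2, 3], counts=[2, 1, 2]
values[np.argmax(counts)]  # -> 1: labels 1 and 3 tie, and the smallest label wins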


def most_common_chunked(
    multi_values: np.ndarray, multi_counts: np.ndarray
) -> np.ndarray:
    """Combine per-chunk (value, count) pairs into the overall most common value."""
    all_values, index = np.unique(multi_values, return_inverse=True)
    all_counts = np.zeros(all_values.size, np.int64)
    np.add.at(all_counts, index, multi_counts.ravel())  # in-place count summation
    return all_values[all_counts.argmax()]
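
A small sanity check of the combining logic (illustrative values):

values = np.array([1, 2, 2, 3])  # unique labels reported by two chunks
counts = np.array([5, 3, 4, 1])  # the corresponding per-chunk counts
most_common_chunked(values, counts)  # -> 2 (total count 3 + 4 = 7)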


def _custom_grouped_reduction(
    group_idx: np.ndarray,
    array: np.ndarray,
    *,
    axis: int = -1,
    size: int | None = None,
    fill_value=None,
    dtype=None,
) -> np.ndarray:
    """Custom grouped reduction for flox.Aggregation to get the most common label.

    Args:
        group_idx: Integer codes for the group labels (1D).
        array: Values to reduce (nD).
        axis: Axis of array along which to reduce.
            Requires array.shape[axis] == len(group_idx).
        size: Expected number of groups. If None,
            output.shape[-1] == number of unique values in group_idx.
        fill_value: Fill value for when the number of groups in group_idx is
            less than size.
        dtype: dtype of the output.

    Returns:
        np.ndarray with array.shape[-1] == size, containing a single value per group.
    """
    return npg.aggregate_numpy.aggregate(
        group_idx,
        array,
        func=_most_common_label,
        axis=axis,
        size=size,
        fill_value=fill_value,
        dtype=dtype,
    )
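
A standalone check of the grouped reduction (illustrative values; a concrete fill_value is passed here since numpy_groupies uses it to initialize its output):

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([7, 7, 3, 4, 4])
_custom_grouped_reduction(group_idx, array, size=2, fill_value=-1)
# -> one value per group: [7, 4]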

src/xarray_regrid/utils.py

Lines changed: 12 additions & 0 deletions
@@ -150,3 +150,15 @@ def create_dot_dataarray(
            f"target_{coord}": target_coords,
        },
    )


def common_coords(
    data1: xr.DataArray | xr.Dataset,
    data2: xr.DataArray | xr.Dataset,
    remove_coord: str | None = None,
) -> set[str]:
    """Return the set of coords which two datasets/dataarrays have in common."""
    coords = set(data1.coords).intersection(set(data2.coords))
    if remove_coord in coords:
        coords.remove(remove_coord)
    return coords
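
For example (hypothetical grids, not part of the commit):

ds_a = xr.Dataset(coords={"time": [0, 1], "latitude": [0.0, 1.0], "longitude": [0.0, 1.0]})
ds_b = xr.Dataset(coords={"time": [0], "latitude": [0.0, 0.5, 1.0], "longitude": [0.0, 0.5, 1.0]})
common_coords(ds_a, ds_b)                       # {"time", "latitude", "longitude"}
common_coords(ds_a, ds_b, remove_coord="time")  # {"latitude", "longitude"}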
