"""Implementation of the "most common value" regridding method."""

from itertools import product
from typing import overload

import flox.xarray
import numpy as np
import numpy_groupies as npg
import pandas as pd
import xarray as xr
from flox import Aggregation

from xarray_regrid import utils


@overload
def most_common_wrapper(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    time_dim: str = "",
    max_mem: int | None = None,
) -> xr.DataArray:
    ...


@overload
def most_common_wrapper(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    time_dim: str = "",
    max_mem: int | None = None,
) -> xr.Dataset:
    ...


def most_common_wrapper(
    data: xr.DataArray | xr.Dataset,
    target_ds: xr.Dataset,
    time_dim: str = "",
    max_mem: int | None = None,
) -> xr.DataArray | xr.Dataset:
    """Wrapper around the most-common regridder, allowing for analyzing larger datasets.

    Args:
        data: Input dataset.
        target_ds: Dataset onto whose coordinates the input dataset should be
            regridded.
        time_dim: Name of the time dimension, which the regridders skip (they do
            not regrid over time). Defaults to "" (no time dimension).
        max_mem: (Approximate) maximum memory in bytes that the regridding routines
            can use. Note that this is not the total memory consumption and does
            not include the size of the final dataset. If this kwarg is used, the
            regridding is split up into more manageable chunks, which are combined
            into the final dataset.

    Returns:
        xarray.Dataset (or DataArray, matching the input) with regridded
        categorical data.
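
    Example:
        A minimal sketch; ``land_cover`` and ``target_grid`` are hypothetical
        names, not part of this module:

        >>> regridded = most_common_wrapper(  # doctest: +SKIP
        ...     land_cover, target_grid, time_dim="time", max_mem=int(1e9)
        ... )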
    """
    da_name = None
    if isinstance(data, xr.DataArray):
        da_name = "da" if data.name is None else data.name
        data = data.to_dataset(name=da_name)

    coords = utils.common_coords(data, target_ds)
    coord_size = [data[coord].size for coord in coords]
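    # Estimate the memory use as roughly one 8-byte (int64) value per source grid point.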
    mem_usage = np.prod(coord_size) * np.zeros((1,), dtype=np.int64).itemsize

    if max_mem is not None and mem_usage > max_mem:
        result = split_combine_most_common(
            data=data, target_ds=target_ds, time_dim=time_dim, max_mem=max_mem
        )
    else:
        result = most_common(data=data, target_ds=target_ds, time_dim=time_dim)

    if da_name is not None:
        return result[da_name]
    else:
        return result


def split_combine_most_common(
    data: xr.Dataset, target_ds: xr.Dataset, time_dim: str, max_mem: int = int(1e9)
) -> xr.Dataset:
    """Use a split-combine strategy to reduce the memory use of the most_common regrid.

    Args:
        data: Input dataset.
        target_ds: Dataset onto whose coordinates the input dataset should be
            regridded.
        time_dim: Name of the time dimension, which the regridders skip (they do
            not regrid over time).
        max_mem: (Approximate) maximum memory in bytes that the regridding routines
            can use. Note that this is not the total memory consumption and does
            not include the size of the final dataset. Defaults to 1e9 (1 GB).

    Returns:
        xarray.Dataset with regridded categorical data.
    """
    coords = utils.common_coords(data, target_ds, remove_coord=time_dim)
    max_datapoints = max_mem // 8  # ~8 bytes per item.
    max_source_coord_size = max_datapoints ** (1 / len(coords))
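    # Ratio of source to target resolution per coordinate, used to convert the
    # per-coordinate budget of source cells into a number of target cells.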
    size_ratios = {
        coord: (
            np.median(np.diff(data[coord].to_numpy(), 1))
            / np.median(np.diff(target_ds[coord].to_numpy(), 1))
        )
        for coord in coords
    }
    max_coord_size = {
        coord: int(size_ratios[coord] * max_source_coord_size) for coord in coords
    }

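    # Start indices of the target-grid blocks along each coordinate.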
    blocks = {
        coord: np.arange(0, target_ds[coord].size, max_coord_size[coord])
        for coord in coords
    }

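    # Regrid each block of the target grid independently, then merge the results.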
    subsets = []
    for vals in product(*blocks.values()):
        isel = {}
        for coord, val in zip(blocks.keys(), vals, strict=True):
            isel[coord] = slice(val, val + max_coord_size[coord])
        subsets.append(most_common(data, target_ds.isel(isel), time_dim=time_dim))

    return xr.merge(subsets)


def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Dataset:
    """Regrid categorical data with a "most common label" approach.

    The implementation consists of two steps:
    - "groupby" the coordinates
    - select the most common label per group

    We use flox to perform the "groupby" over multiple dimensions. Here is an
    example:
    https://flox.readthedocs.io/en/latest/intro.html#histogramming-binning-by-multiple-variables

    To plug in our customized function for selecting the most common label, we
    need to create our own `flox.Aggregation`, for instance:
    https://flox.readthedocs.io/en/latest/aggregations.html

    `flox.Aggregation` works with the `numpy_groupies.aggregate_numpy.aggregate`
    API. Therefore this function also depends on `numpy_groupies`. For more
    information, check the following example:
    https://flox.readthedocs.io/en/latest/user-stories/custom-aggregations.html

    Note that this function cannot handle datasets that do not fit into memory.
    In that case, please first coarsen the land cover data with the `coarse`
    function and then apply the regridder.

    Args:
        data: Input dataset.
        target_ds: Dataset onto whose coordinates the input dataset should be
            regridded.
        time_dim: Name of the time dimension, which is excluded from regridding.

    Returns:
        xarray.Dataset with regridded categorical data.
    """
    coords = utils.common_coords(data, target_ds, remove_coord=time_dim)
    bounds = tuple(
        _construct_intervals(target_ds[coord].to_numpy()) for coord in coords
    )

    # Slice the input data to the bounds of the target dataset, padded by one
    # target cell on each side so that the edge bins keep full coverage.
    data = data.sortby(list(coords))
    for coord in coords:
        coord_res = np.median(np.diff(target_ds[coord].to_numpy(), 1))
        data = data.sel(
            {
                coord: slice(
                    target_ds[coord].min().to_numpy() - coord_res,
                    target_ds[coord].max().to_numpy() + coord_res,
                )
            }
        )

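    # Custom flox Aggregation that reduces every group to its most common label.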
    most_common_agg = Aggregation(
        name="most_common",
        numpy=_custom_grouped_reduction,
        chunk=None,
        combine=None,
    )

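    # Bin the source cells into the target-cell intervals (`expected_groups`)
    # along each coordinate and reduce each bin to its most common label.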
    ds_regrid: xr.Dataset = flox.xarray.xarray_reduce(
        data.compute(),
        *coords,
        func=most_common_agg,
        expected_groups=bounds,
    )

    ds_regrid = ds_regrid.rename({f"{coord}_bins": coord for coord in coords})
    for coord in coords:
        ds_regrid[coord] = target_ds[coord]

    return ds_regrid


def _construct_intervals(coord: np.ndarray) -> pd.IntervalIndex:
    """Create a pandas.IntervalIndex of bins centered on the given coordinates."""
    step_size = np.median(np.diff(coord, n=1))
    breaks = np.append(coord, coord[-1] + step_size) - step_size / 2
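    # For example, coords [0.0, 1.0, 2.0] give breaks [-0.5, 0.5, 1.5, 2.5],
    # i.e. the left-closed intervals [-0.5, 0.5), [0.5, 1.5), [1.5, 2.5).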

    # Note: closed="both" triggers a `NotImplementedError`.
    return pd.IntervalIndex.from_breaks(breaks, closed="left")


def _most_common_label(neighbors: np.ndarray) -> np.ndarray:
    """Find the most common label in a neighborhood.

    Note that if multiple labels are tied for the highest frequency, the first
    label in sorted order is picked (np.unique returns sorted labels).
    """
    unique_labels, counts = np.unique(neighbors, return_counts=True)
    return unique_labels[np.argmax(counts)]


def most_common_chunked(multi_values: np.ndarray, multi_counts: np.ndarray):
    """Combine per-chunk (values, counts) pairs into the overall most common value."""
    all_values, index = np.unique(multi_values, return_inverse=True)
    all_counts = np.zeros(all_values.size, np.int64)
    np.add.at(all_counts, index, multi_counts.ravel())  # in-place addition of counts
    return all_values[all_counts.argmax()]


def _custom_grouped_reduction(
    group_idx: np.ndarray,
    array: np.ndarray,
    *,
    axis: int = -1,
    size: int | None = None,
    fill_value=None,
    dtype=None,
) -> np.ndarray:
    """Custom grouped reduction for flox.Aggregation to get the most common label.

    Args:
        group_idx: Integer codes for group labels (1D).
        array: Values to reduce (nD).
        axis: Axis of array along which to reduce.
            Requires array.shape[axis] == len(group_idx).
        size: Expected number of groups. If None,
            output.shape[-1] == number of unique values in group_idx.
        fill_value: Fill value for when the number of groups in group_idx is
            less than size.
        dtype: Dtype of the output.

    Returns:
        np.ndarray with array.shape[-1] == size, containing a single value per
        group.
    """
    return npg.aggregate_numpy.aggregate(
        group_idx,
        array,
        func=_most_common_label,
        axis=axis,
        size=size,
        fill_value=fill_value,
        dtype=dtype,
    )