Commit 9d71962

Merge pull request #45 from xarray-contrib/boundary-padding

Spherical padding and faster tests

2 parents 8bdc636 + 1f2e999
5 files changed, +461 -125 lines changed

src/xarray_regrid/methods/conservative.py

Lines changed: 4 additions & 7 deletions

@@ -72,8 +72,10 @@ def conservative_regrid(
 
     # Make sure the regridding coordinates are sorted
     coord_names = [coord for coord in target_ds.coords if coord in data.coords]
-    target_ds_sorted = target_ds.sortby(coord_names)
-    data = data.sortby(list(coord_names))
+    target_ds_sorted = xr.Dataset(coords=target_ds.coords)
+    for coord_name in coord_names:
+        target_ds_sorted = utils.ensure_monotonic(target_ds_sorted, coord_name)
+        data = utils.ensure_monotonic(data, coord_name)
     coords = {name: target_ds_sorted[name] for name in coord_names}
 
     regridded_data = utils.call_on_dataset(
@@ -122,15 +124,13 @@ def conservative_regrid_dataset(
         weights = apply_spherical_correction(weights, latitude_coord)
 
         for array in data_vars.keys():
-            non_grid_dims = [d for d in data_vars[array].dims if d not in coords]
             if coord in data_vars[array].dims:
                 data_vars[array], valid_fracs[array] = apply_weights(
                     da=data_vars[array],
                     weights=weights,
                     coord=coord,
                     valid_frac=valid_fracs[array],
                     skipna=skipna,
-                    non_grid_dims=non_grid_dims,
                 )
                 # Mask out any regridded points outside the original domain
                 data_vars[array] = data_vars[array].where(covered_grid)
@@ -161,16 +161,13 @@ def apply_weights(
     coord: Hashable,
     valid_frac: xr.DataArray,
     skipna: bool,
-    non_grid_dims: list[Hashable],
 ) -> tuple[xr.DataArray, xr.DataArray]:
     """Apply the weights to convert data to the new coordinates."""
     coord_map = {f"target_{coord}": coord}
     weights_norm = weights.copy()
 
     if skipna:
         notnull = da.notnull()
-        if non_grid_dims:
-            notnull = notnull.any(non_grid_dims)
         # Renormalize the weights along this dim by the accumulated valid_frac
         # along previous dimensions
         if valid_frac.name != EMPTY_DA_NAME:
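The first hunk above swaps the blanket `sortby` for per-coordinate `utils.ensure_monotonic` calls (the helper is added to `utils.py` later in this commit), which is a cheap no-op when an index is already sorted and unique, whereas the old unconditional `sortby` always paid the reindexing cost. A minimal sketch with toy data, not taken from the test suite:

    import numpy as np
    import xarray as xr

    from xarray_regrid import utils

    # Already monotonic and unique: the object is returned untouched.
    ds = xr.Dataset(coords={"latitude": np.arange(-89.5, 90.0, 1.0)})
    assert utils.ensure_monotonic(ds, "latitude") is ds

    # A descending index is sorted, matching the old sortby behavior.
    ds_desc = ds.sortby("latitude", ascending=False)
    sorted_ds = utils.ensure_monotonic(ds_desc, "latitude")
    assert sorted_ds.indexes["latitude"].is_monotonic_increasing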

src/xarray_regrid/regrid.py

Lines changed: 15 additions & 6 deletions

@@ -1,6 +1,7 @@
 import xarray as xr
 
 from xarray_regrid.methods import conservative, interp, most_common
+from xarray_regrid.utils import format_for_regrid
 
 
 @xr.register_dataarray_accessor("regrid")
@@ -34,7 +35,8 @@ def linear(
             Data regridded to the target dataset coordinates.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
-        return interp.interp_regrid(self._obj, ds_target_grid, "linear")
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
+        return interp.interp_regrid(ds_formatted, ds_target_grid, "linear")
 
     def nearest(
         self,
@@ -51,14 +53,14 @@ def nearest(
             Data regridded to the target dataset coordinates.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
-        return interp.interp_regrid(self._obj, ds_target_grid, "nearest")
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
+        return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest")
 
     def cubic(
         self,
         ds_target_grid: xr.Dataset,
         time_dim: str = "time",
     ) -> xr.DataArray | xr.Dataset:
-        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
         """Regrid to the coords of the target dataset with cubic interpolation.
 
         Args:
@@ -68,7 +70,9 @@ def cubic(
         Returns:
             Data regridded to the target dataset coordinates.
         """
-        return interp.interp_regrid(self._obj, ds_target_grid, "cubic")
+        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
+        return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")
 
     def conservative(
         self,
@@ -88,6 +92,9 @@ def conservative(
             time_dim: The name of the time dimension/coordinate.
             skipna: If True, enable handling for NaN values. This adds some overhead,
                 so can be disabled for optimal performance on data without any NaNs.
+                With `skipna=True`, chunking is recommended in the non-grid dimensions,
+                otherwise the intermediate arrays that track the fraction of valid data
+                can become very large and consume excessive memory.
                 Warning: with `skipna=False`, isolated NaNs will propagate throughout
                 the dataset due to the sequential regridding scheme over each dimension.
             nan_threshold: Threshold value that will retain any output points
@@ -104,8 +111,9 @@ def conservative(
             raise ValueError(msg)
 
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
         return conservative.conservative_regrid(
-            self._obj, ds_target_grid, latitude_coord, skipna, nan_threshold
+            ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold
         )
 
     def most_common(
@@ -134,8 +142,9 @@ def most_common(
             Regridded data.
        """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
         return most_common.most_common_wrapper(
-            self._obj, ds_target_grid, time_dim, max_mem
+            ds_formatted, ds_target_grid, time_dim, max_mem
         )
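Every accessor method above now routes the source data through `format_for_regrid` before dispatching, so the spherical padding is applied transparently. A usage sketch, with file names and chunking as illustrative assumptions (not from the repository), and assuming the matching Dataset accessor registered in the same module:

    import xarray as xr
    import xarray_regrid  # noqa: F401  # importing registers the .regrid accessor

    # Hypothetical inputs: any data with recognizable lat/lon coordinate names.
    ds = xr.open_dataset("source.nc", chunks={"time": 1})
    target = xr.open_dataset("target_grid.nc")

    # Chunking over the non-grid time dimension follows the skipna advice in
    # the docstring above; a 0..360 source can now cover a -180..180 target
    # without manual longitude shifts.
    result = ds.regrid.conservative(target, skipna=True)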
src/xarray_regrid/utils.py

Lines changed: 198 additions & 21 deletions

@@ -1,6 +1,6 @@
-from collections.abc import Callable
+from collections.abc import Callable, Hashable
 from dataclasses import dataclass
-from typing import Any, overload
+from typing import Any, TypedDict, overload
 
 import numpy as np
 import pandas as pd
@@ -10,6 +10,11 @@
 class InvalidBoundsError(Exception): ...
 
 
+class CoordHandler(TypedDict):
+    names: list[str]
+    func: Callable
+
+
 @dataclass
 class Grid:
     """Object storing grid information."""
@@ -75,7 +80,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
         grid.south, grid.north + grid.resolution_lat, grid.resolution_lat
     )
 
-    if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0:
+    if np.remainder((grid.east - grid.west), grid.resolution_lat) > 0:
         lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon)
     else:
         lon_coords = np.arange(
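The fix above swaps the latitude extent (`north - south`) out of the longitude branch check: previously, a grid whose latitude span divided evenly by the resolution but whose longitude span did not would overshoot the east edge. A worked sketch, assuming the package's public `Grid` and `create_regridding_dataset` helpers and equal lat/lon resolutions (for which the remaining `resolution_lat` divisor in the new check is equivalent):

    from xarray_regrid import Grid, create_regridding_dataset

    # Latitude spans 180 (an even multiple of 1); longitude spans 359.5 (not).
    grid = Grid(
        north=90, south=-90, west=0, east=359.5,
        resolution_lat=1, resolution_lon=1,
    )
    ds_grid = create_regridding_dataset(grid)
    # With the corrected east-west check, longitude stops at 359.0 rather
    # than running past the 359.5 edge to 360.0.
    print(ds_grid["longitude"].values[-1])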
@@ -193,24 +198,6 @@ def common_coords(
     return sorted([str(coord) for coord in coords])
 
 
-@overload
-def call_on_dataset(
-    func: Callable[..., xr.Dataset],
-    obj: xr.DataArray,
-    *args: Any,
-    **kwargs: Any,
-) -> xr.DataArray: ...
-
-
-@overload
-def call_on_dataset(
-    func: Callable[..., xr.Dataset],
-    obj: xr.Dataset,
-    *args: Any,
-    **kwargs: Any,
-) -> xr.Dataset: ...
-
-
 def call_on_dataset(
     func: Callable[..., xr.Dataset],
     obj: xr.DataArray | xr.Dataset,
@@ -235,3 +222,193 @@ def call_on_dataset(
         return next(iter(result.data_vars.values())).rename(obj.name)
 
     return result
+
+
+def format_for_regrid(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset
+) -> xr.DataArray | xr.Dataset:
+    """Apply any pre-formatting to the input dataset to prepare for regridding.
+    Currently handles padding of spherical geometry if lat/lon coordinates can
+    be inferred and the domain size requires boundary padding.
+    """
+    orig_chunksizes = obj.chunksizes
+
+    # Special-cased coordinates with accepted names and formatting function
+    coord_handlers: dict[str, CoordHandler] = {
+        "lat": {"names": ["lat", "latitude"], "func": format_lat},
+        "lon": {"names": ["lon", "longitude"], "func": format_lon},
+    }
+    # Identify coordinates that need to be formatted
+    formatted_coords = {}
+    for coord_type, handler in coord_handlers.items():
+        for coord in obj.coords.keys():
+            if str(coord).lower() in handler["names"]:
+                formatted_coords[coord_type] = str(coord)
+
+    # Apply formatting
+    for coord_type, coord in formatted_coords.items():
+        # Make sure formatted coords are sorted
+        obj = ensure_monotonic(obj, coord)
+        target = ensure_monotonic(target, coord)
+        obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords)
+        # Coerce back to a single chunk if that's what was passed
+        if len(orig_chunksizes.get(coord, [])) == 1:
+            obj = obj.chunk({coord: -1})
+
+    return obj
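Taken together, this routine sorts the lat/lon indexes, pads the poles, and wrap-pads the longitudes in one pass. A sketch of the end-to-end effect on toy global data; the expected values follow from the rules in `format_lat`/`format_lon` below:

    import numpy as np
    import xarray as xr

    from xarray_regrid.utils import format_for_regrid

    # 1-degree global source: cell centers 0.5..359.5 and -89.5..89.5.
    src = xr.Dataset(
        {"tas": (("lat", "lon"), np.zeros((180, 360)))},
        coords={
            "lat": np.arange(-89.5, 90.0, 1.0),
            "lon": np.arange(0.5, 360.0, 1.0),
        },
    )
    # Half-degree target on a -180..180 longitude convention.
    target = xr.Dataset(
        coords={
            "lat": np.arange(-90.0, 90.5, 0.5),
            "lon": np.arange(-180.0, 180.5, 0.5),
        }
    )

    out = format_for_regrid(src, target)
    print(out["lat"].values[[0, -1]])  # [-90.  90.] after pole padding
    print(out["lon"].values[[0, -1]])  # [-180.5  180.5] after wraparound padding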
+
+
+def format_lat(
+    obj: xr.DataArray | xr.Dataset,
+    target: xr.Dataset,  # noqa ARG001
+    formatted_coords: dict[str, str],
+) -> xr.DataArray | xr.Dataset:
+    """If the latitude coordinate is inferred to be global, defined as having
+    a value within one grid spacing of the poles, and the grid does not natively
+    have values at -90 and 90, add a single value at each pole computed as the
+    mean of the first and last latitude bands. This should be roughly equivalent
+    to the `Pole="all"` option in `ESMF`.
+
+    For example, with a grid spacing of 1 degree, and a source grid ranging from
+    -89.5 to 89.5, the poles would be padded with values at -90 and 90. A grid ranging
+    from -88 to 88 would not be padded because coverage does not extend all the way
+    to the poles. A grid ranging from -90 to 90 would also not be padded because the
+    poles will already be covered in the regridding weights.
+    """
+    lat_coord = formatted_coords["lat"]
+    lon_coord = formatted_coords.get("lon")
+
+    # Concat a padded value representing the mean of the first/last lat bands
+    # This should match the Pole="all" option of ESMF
+    # TODO: with cos(90) = 0 weighting, these weights might be 0?
+
+    polar_lat = 90
+    dy = obj.coords[lat_coord].diff(lat_coord).max().values.item()
+
+    # Only pad if global but don't have edge values directly at poles
+    # NOTE: could use xr.pad here instead of xr.concat, but none of the
+    # modes are an exact fit for this scheme
+    lat_vals = obj.coords[lat_coord].values
+    # South pole
+    if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat:
+        south_pole = obj.isel({lat_coord: 0})
+        if lon_coord is not None:
+            south_pole = south_pole.mean(lon_coord)
+        obj = xr.concat([south_pole, obj], dim=lat_coord)  # type: ignore
+        lat_vals = np.concatenate([[-polar_lat], lat_vals])
+
+    # North pole
+    if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat:
+        north_pole = obj.isel({lat_coord: -1})
+        if lon_coord is not None:
+            north_pole = north_pole.mean(lon_coord)
+        obj = xr.concat([obj, north_pole], dim=lat_coord)  # type: ignore
+        lat_vals = np.concatenate([lat_vals, [polar_lat]])
+
+    obj = update_coord(obj, lat_coord, lat_vals)
+
+    return obj
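The docstring's three scenarios reduce to the two inequality tests in the body above; as a worked check (plain arithmetic, not library code, with dy as the maximum latitude spacing):

    dy = 1.0

    # -89.5..89.5: within one spacing of each pole but not on it -> pad both.
    assert dy - 90 >= -89.5 > -90
    assert 90 - dy <= 89.5 < 90

    # -88..88: more than one spacing from the poles -> no padding.
    assert not (dy - 90 >= -88.0 > -90)

    # -90..90: already on the poles -> no padding (the strict bound fails).
    assert not (dy - 90 >= -90.0 > -90)
    assert not (90 - dy <= 90.0 < 90)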
+
+
+def format_lon(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str]
+) -> xr.DataArray | xr.Dataset:
+    """Format the longitude coordinate by shifting the source grid to line up with
+    the target anywhere in the range of -360 to 360, and then add a single wraparound
+    padding column if the domain is inferred to be global and the east or west edges
+    of the target lie outside the source grid centers.
+
+    For example, with a source grid ranging from 0.5 to 359.5 and a target grid ranging
+    from -180 to 180, the source grid would be shifted to -179.5 to 179.5 and then
+    padded on both the left and right with wraparound values at -180.5 and 180.5 to
+    provide full coverage for the target edge cells at -180 and 180.
+    """
+    lon_coord = formatted_coords["lon"]
+
+    # Find a wrap point outside of the left and right bounds of the target
+    # This ensures we have coverage on the target and handles global > regional
+    source_vals = obj.coords[lon_coord].values
+    target_vals = target.coords[lon_coord].values
+    wrap_point = (target_vals[-1] + target_vals[0] + 360) / 2
+    source_vals = np.where(
+        source_vals < wrap_point - 360, source_vals + 360, source_vals
+    )
+    source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals)
+    obj = update_coord(obj, lon_coord, source_vals)
+
+    obj = ensure_monotonic(obj, lon_coord)
+
+    # Only pad if domain is global in lon
+    source_lon = obj.coords[lon_coord]
+    target_lon = target.coords[lon_coord]
+    dx_s = source_lon.diff(lon_coord).max().values.item()
+    dx_t = target_lon.diff(lon_coord).max().values.item()
+    is_global_lon = source_lon.max().values - source_lon.min().values >= 360 - dx_s
+
+    if is_global_lon:
+        left_pad = (source_lon.values[0] - target_lon.values[0] + dx_t / 2) / dx_s
+        right_pad = (target_lon.values[-1] - source_lon.values[-1] + dx_t / 2) / dx_s
+        left_pad = int(np.ceil(np.max([left_pad, 0])))
+        right_pad = int(np.ceil(np.max([right_pad, 0])))
+        obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True)
+        lon_vals = obj.coords[lon_coord].values
+        if left_pad:
+            lon_vals[:left_pad] = source_lon.values[-left_pad:] - 360
+        if right_pad:
+            lon_vals[-right_pad:] = source_lon.values[:right_pad] + 360
+        obj = update_coord(obj, lon_coord, lon_vals)
+    return obj
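The docstring's shift example, reduced to plain numpy to make the wrap-point arithmetic concrete (a worked check, not library code):

    import numpy as np

    source = np.arange(0.5, 360.0, 1.0)  # cell centers 0.5..359.5
    target = np.array([-180.0, 180.0])   # target spanning -180..180

    wrap_point = (target[-1] + target[0] + 360) / 2  # 180.0
    shifted = np.sort(np.where(source > wrap_point, source - 360, source))
    print(shifted[0], shifted[-1])  # -179.5 179.5, ready for wraparound padding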
+
+
+def coord_is_covered(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable
+) -> bool:
+    """Check if the source coord fully covers the target coord."""
+    pad = target[coord].diff(coord).max().values
+    left_covered = obj[coord].min() <= target[coord].min() - pad
+    right_covered = obj[coord].max() >= target[coord].max() + pad
+    return bool(left_covered.item() and right_covered.item())
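A quick check of the covering test with toy values; the one-cell margin comes from the target's own spacing:

    import numpy as np
    import xarray as xr

    from xarray_regrid.utils import coord_is_covered

    src = xr.Dataset(coords={"x": np.arange(-5.0, 106.0, 1.0)})  # -5..105
    tgt = xr.Dataset(coords={"x": np.arange(0.0, 101.0, 2.0)})   # 0..100

    # The source reaches at least one target cell width (2.0) beyond both
    # target edges, so the target is fully covered.
    assert coord_is_covered(src, tgt, "x")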
+
+
+@overload
+def ensure_monotonic(obj: xr.DataArray, coord: Hashable) -> xr.DataArray: ...
+
+
+@overload
+def ensure_monotonic(obj: xr.Dataset, coord: Hashable) -> xr.Dataset: ...
+
+
+def ensure_monotonic(
+    obj: xr.DataArray | xr.Dataset, coord: Hashable
+) -> xr.DataArray | xr.Dataset:
+    """Ensure that an object has monotonically increasing indexes for a
+    given coordinate. Only sort and drop duplicates if needed because this
+    requires reindexing which can be expensive."""
+    if not obj.indexes[coord].is_monotonic_increasing:
+        obj = obj.sortby(coord)
+    if not obj.indexes[coord].is_unique:
+        obj = obj.drop_duplicates(coord)
+    return obj
+
+
+@overload
+def update_coord(
+    obj: xr.DataArray, coord: Hashable, coord_vals: np.ndarray
+) -> xr.DataArray: ...
+
+
+@overload
+def update_coord(
+    obj: xr.Dataset, coord: Hashable, coord_vals: np.ndarray
+) -> xr.Dataset: ...
+
+
+def update_coord(
+    obj: xr.DataArray | xr.Dataset, coord: Hashable, coord_vals: np.ndarray
+) -> xr.DataArray | xr.Dataset:
+    """Update the values of a coordinate, ensuring indexes stay in sync."""
+    attrs = obj.coords[coord].attrs
+    obj = obj.assign_coords({coord: coord_vals})
+    obj.coords[coord].attrs = attrs
+    return obj
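The attrs round-trip is the point of this helper: a bare `assign_coords` with a raw array drops coordinate attributes, which matters for CF metadata such as units. A small check with toy attrs:

    import numpy as np
    import xarray as xr

    from xarray_regrid.utils import update_coord

    ds = xr.Dataset(
        coords={"lon": ("lon", np.arange(4.0), {"units": "degrees_east"})}
    )
    shifted = update_coord(ds, "lon", ds["lon"].values - 360)
    assert shifted["lon"].attrs["units"] == "degrees_east"  # attrs preserved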
