Skip to content

Commit edc2bdc

Browse files
committed
Dask, spherical correction, tests for conservative
1 parent a0dadba commit edc2bdc

File tree

6 files changed

+637
-314
lines changed

6 files changed

+637
-314
lines changed

benchmarks/benchmarking_conservative.ipynb

Lines changed: 250 additions & 243 deletions
Large diffs are not rendered by default.
-11.6 MB
Binary file not shown.

src/xarray_regrid/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
# Package entry point: re-export the user-facing regridding API so callers can
# write e.g. `from xarray_regrid import DatasetRegridder`.
from xarray_regrid.regrid import DataArrayRegridder, DatasetRegridder
from xarray_regrid.utils import Grid, create_regridding_dataset

# Explicit public API of the package (also controls `from xarray_regrid import *`).
__all__ = [
    "Grid",
    "DataArrayRegridder",
    "DatasetRegridder",
    "create_regridding_dataset",
]

src/xarray_regrid/methods.py

Lines changed: 182 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,36 @@
1-
from typing import Literal
1+
from collections.abc import Hashable
2+
from typing import Literal, overload
23

4+
import dask.array
5+
import numpy as np
36
import xarray as xr
47

58
from xarray_regrid import utils
69

710

# Typing overloads only: interp_regrid preserves the container type of its
# input (DataArray in -> DataArray out, Dataset in -> Dataset out).
@overload
def interp_regrid(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.DataArray:
    ...


@overload
def interp_regrid(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.Dataset:
    ...
27+
28+
29+
def interp_regrid(
30+
data: xr.DataArray | xr.Dataset,
31+
target_ds: xr.Dataset,
32+
method: Literal["linear", "nearest", "cubic"],
33+
) -> xr.DataArray | xr.Dataset:
1334
"""Refine a dataset using xarray's interp method.
1435
1536
Args:
@@ -29,10 +50,29 @@ def interp_regrid(
2950
)
3051

3152

53+
# Typing overloads only: conservative_regrid preserves the container type of
# its input (DataArray in -> DataArray out, Dataset in -> Dataset out).
@overload
def conservative_regrid(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    latitude_coord: str | None,
) -> xr.DataArray:
    ...


@overload
def conservative_regrid(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    latitude_coord: str | None,
) -> xr.Dataset:
    ...
69+
70+
71+
def conservative_regrid(
72+
data: xr.DataArray | xr.Dataset,
73+
target_ds: xr.Dataset,
74+
latitude_coord: str | None,
75+
) -> xr.DataArray | xr.Dataset:
3676
"""Refine a dataset using conservative regridding.
3777
3878
The method implementation is based on a post by Stephan Hoyer; "For the case of
@@ -49,34 +89,156 @@ def conservative_regrid(
4989
Returns:
5090
Regridded input dataset
5191
"""
92+
if latitude_coord is not None:
93+
if latitude_coord not in data.coords:
94+
msg = "Latitude coord not in input data!"
95+
raise ValueError(msg)
96+
else:
97+
latitude_coord = ""
98+
99+
dim_order = list(target_ds.dims)
100+
52101
coord_names = set(target_ds.coords).intersection(set(data.coords))
53102
coords = {name: target_ds[name] for name in coord_names}
54103
data = data.sortby(list(coord_names))
55104

56-
# TODO: filter out data vars lacking the target coordinates
105+
if isinstance(data, xr.Dataset):
106+
return conservative_regrid_dataset(data, coords, latitude_coord).transpose(
107+
*dim_order, ...
108+
)
109+
else:
110+
return conservative_regrid_dataarray(data, coords, latitude_coord).transpose(
111+
*dim_order, ...
112+
)
113+
114+
115+
def conservative_regrid_dataset(
    data: xr.Dataset,
    coords: dict[Hashable, xr.DataArray],
    latitude_coord: str,
) -> xr.Dataset:
    """Dataset implementation of the conservative regridding method.

    Each data variable is regridded independently along every coordinate
    present in both the variable and the target coordinates.
    """
    regridded = [data[name] for name in data.data_vars]

    for coord in coords:
        new_coords = coords[coord].to_numpy()
        old_coords = data[coord].to_numpy()
        weights = get_weights(old_coords, new_coords)

        # Latitude weights get an extra area correction for the spherical earth.
        if str(coord) == latitude_coord:
            dot_array = utils.create_dot_dataarray(
                weights, str(coord), new_coords, old_coords
            )
            weights = apply_spherical_correction(dot_array, latitude_coord).to_numpy()

        for i, da in enumerate(regridded):
            if coord in da.coords:
                # Put the coordinate first so the weights apply along axis 0.
                regridded[i] = apply_weights(
                    da.transpose(coord, ...), weights, coord, new_coords
                )

    return xr.merge(regridded)  # TODO: add other coordinates/data variables back in.
143+
144+
145+
def conservative_regrid_dataarray(
    data: xr.DataArray,
    coords: dict[Hashable, xr.DataArray],
    latitude_coord: str,
) -> xr.DataArray:
    """DataArray implementation of the conservative regridding method."""
    for coord, target in coords.items():
        if coord not in data.coords:
            continue

        new_coords = target.to_numpy()
        old_coords = data[coord].to_numpy()
        weights = get_weights(old_coords, new_coords)

        # Latitude weights get an extra area correction for the spherical earth.
        if str(coord) == latitude_coord:
            dot_array = utils.create_dot_dataarray(
                weights, str(coord), new_coords, old_coords
            )
            weights = apply_spherical_correction(dot_array, latitude_coord).to_numpy()

        # Put the coordinate first so the weights apply along axis 0.
        data = apply_weights(data.transpose(coord, ...), weights, coord, new_coords)

    return data
170+
171+
172+
def apply_weights(
    da: xr.DataArray, weights: np.ndarray, coord_name: Hashable, new_coords: np.ndarray
) -> xr.DataArray:
    """Apply the regridding weights along the leading dimension of ``da``.

    Args:
        da: Input data, already transposed so ``coord_name`` is the first
            dimension.
        weights: Weight matrix of shape (n_source, n_target) mapping the old
            coordinate values onto the new ones.
        coord_name: Name of the dimension/coordinate being regridded.
        new_coords: Target coordinate values along ``coord_name``.

    Returns:
        DataArray on the new coordinates; all other dimensions are unchanged.
    """
    if da.chunks is not None:
        # Dask routine: keeps the computation lazy and chunked.
        new_data = dask.array.einsum(
            "i...,ij->j...", da.data, weights, optimize="greedy"
        )
    else:
        # numpy routine
        new_data = np.einsum("i...,ij->j...", da.data, weights)

    # Rebuild the coordinates: the regridded dim gets the target values; every
    # other dim keeps its original coordinate. Dims without an associated
    # coordinate variable are skipped — indexing them would raise a KeyError.
    coord_mapping: dict[Hashable, np.ndarray] = {coord_name: new_coords}
    for dim in da.dims:
        if dim != coord_name and dim in da.coords:
            coord_mapping[dim] = da[dim].to_numpy()

    return xr.DataArray(
        data=new_data,
        dims=da.dims,  # pass dims explicitly so coordinate-less dims survive
        coords=coord_mapping,
        name=da.name,
    )
196+
197+
198+
def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray:
    """Determine the weights to map from the old coordinates to the new coordinates.

    Args:
        source_coords: Source coordinates (center points).
        target_coords: Target coordinates (center points).

    Returns:
        Weights, which can be used with a dot product to apply the conservative
        regrid.
    """
    # TODO: better resolution/IntervalIndex inference
    # Cell width is inferred from the spacing of the first two center points.
    source_intervals = utils.to_intervalindex(
        source_coords, resolution=source_coords[1] - source_coords[0]
    )
    target_intervals = utils.to_intervalindex(
        target_coords, resolution=target_coords[1] - target_coords[0]
    )
    return utils.normalize_overlap(utils.overlap(source_intervals, target_intervals))
218+
219+
220+
def apply_spherical_correction(
    dot_array: xr.DataArray, latitude_coord: str
) -> xr.DataArray:
    """Apply a spherical-earth correction on the prepared dot-product weights.

    Grid cells cover less area towards the poles, so each latitude row of the
    weight matrix is scaled by the relative area of its latitude band, after
    which the weights are re-normalized.

    Args:
        dot_array: Weight matrix as a DataArray carrying the latitude values
            under ``latitude_coord``.
        latitude_coord: Name of the latitude coordinate in ``dot_array``.

    Returns:
        Copy of ``dot_array`` with latitude-corrected, re-normalized weights.
    """
    da = dot_array.copy()
    latitudes = dot_array[latitude_coord].to_numpy()
    # Median spacing as the cell width; assumes a (near) regular latitude grid.
    latitude_res = np.median(np.diff(latitudes, 1))
    lat_weights = lat_weight(latitudes, latitude_res)
    da.values = utils.normalize_overlap(dot_array.values * lat_weights[:, np.newaxis])
    return da
229+
230+
231+
def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray:
    """Return the weight of gridcells based on their latitude.

    The weight is proportional to the surface area of a latitude band of
    width ``latitude_res`` centered on each latitude value.

    Args:
        latitude: (Center) latitude values of the gridcells, in degrees.
        latitude_res: Resolution/width of the grid cells, in degrees.

    Returns:
        Weights, same shape as latitude input.
    """
    half_band = np.radians(latitude_res) / 2.0
    lat_rad = np.radians(latitude)
    # Height of the band on the unit sphere (area of a band ~ its height).
    band_height = np.sin(lat_rad + half_band) - np.sin(lat_rad - half_band)
    return band_height * (2.0 * half_band) / (4.0 * np.pi)  # type: ignore

0 commit comments

Comments
 (0)