Commit 7937260

Refactor methods to module. Fix conservative nan bug. (#19)
* Refactor methods to module. Fix conservative nan bug.
* Fix linting and typing issues
1 parent 56b60fa commit 7937260

File tree: 6 files changed, +115 -52 lines

6 files changed

+115
-52
lines changed

src/xarray_regrid/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,8 +1,12 @@
+from xarray_regrid import methods
 from xarray_regrid.regrid import Regridder
 from xarray_regrid.utils import Grid, create_regridding_dataset
 
 __all__ = [
     "Grid",
     "Regridder",
     "create_regridding_dataset",
+    "methods",
 ]
+
+__version__ = "0.2.0"
src/xarray_regrid/methods/conservative.py

Lines changed: 35 additions & 47 deletions
@@ -1,5 +1,6 @@
+"""Conservative regridding implementation."""
 from collections.abc import Hashable
-from typing import Literal, overload
+from typing import overload
 
 import dask.array
 import numpy as np
@@ -8,48 +9,6 @@
 from xarray_regrid import utils
 
 
-@overload
-def interp_regrid(
-    data: xr.DataArray,
-    target_ds: xr.Dataset,
-    method: Literal["linear", "nearest", "cubic"],
-) -> xr.DataArray:
-    ...
-
-
-@overload
-def interp_regrid(
-    data: xr.Dataset,
-    target_ds: xr.Dataset,
-    method: Literal["linear", "nearest", "cubic"],
-) -> xr.Dataset:
-    ...
-
-
-def interp_regrid(
-    data: xr.DataArray | xr.Dataset,
-    target_ds: xr.Dataset,
-    method: Literal["linear", "nearest", "cubic"],
-) -> xr.DataArray | xr.Dataset:
-    """Refine a dataset using xarray's interp method.
-
-    Args:
-        data: Input dataset.
-        target_ds: Dataset which coordinates the input dataset should be regrid to.
-        method: Which interpolation method to use (e.g. 'linear', 'nearest').
-
-    Returns:
-        Regridded input dataset
-    """
-    coord_names = set(target_ds.coords).intersection(set(data.coords))
-    coords = {name: target_ds[name] for name in coord_names}
-
-    return data.interp(
-        coords=coords,
-        method=method,
-    )
-
-
 @overload
 def conservative_regrid(
     data: xr.DataArray,
@@ -173,14 +132,13 @@ def apply_weights(
     da: xr.DataArray, weights: np.ndarray, coord_name: Hashable, new_coords: np.ndarray
 ) -> xr.DataArray:
     """Apply the weights to convert data to the new coordinates."""
+    new_data: np.ndarray | dask.array.Array
     if da.chunks is not None:
         # Dask routine
-        new_data = dask.array.einsum(
-            "i...,ij->j...", da.data, weights, optimize="greedy"
-        )
+        new_data = compute_einsum_dask(da, weights)
     else:
         # numpy routine
-        new_data = np.einsum("i...,ij->j...", da.data, weights)
+        new_data = compute_einsum_numpy(da, weights)
 
     coord_mapping = {coord_name: new_coords}
     coords = list(da.dims)
@@ -195,6 +153,36 @@ def apply_weights(
     )
 
 
+def compute_einsum_dask(da: xr.DataArray, weights: np.ndarray) -> dask.array.Array:
+    """Compute the einsum between dask data and weights, and mask NaNs if needed."""
+    new_data: dask.array.Array
+    if np.any(np.isnan(da.data)):
+        new_data = dask.array.einsum(
+            "i...,ij->j...", da.fillna(0).data, weights, optimize="greedy"
+        )
+        isnan = dask.array.einsum(
+            "i...,ij->j...", np.isnan(da.data), weights, optimize="greedy"
+        )
+        new_data[isnan > 0] = np.nan
+    else:
+        new_data = dask.array.einsum(
+            "i...,ij->j...", da.data, weights, optimize="greedy"
+        )
+    return new_data
+
+
+def compute_einsum_numpy(da: xr.DataArray, weights: np.ndarray) -> np.ndarray:
+    """Compute the einsum between numpy data and weights, and mask NaNs if needed."""
+    new_data: np.ndarray
+    if np.any(np.isnan(da.data)):
+        new_data = np.einsum("i...,ij->j...", da.fillna(0).data, weights)
+        isnan = np.einsum("i...,ij->j...", np.isnan(da.data), weights)
+        new_data[isnan > 0] = np.nan
+    else:
+        new_data = np.einsum("i...,ij->j...", da.data, weights)
+    return new_data
+
+
 def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray:
     """Determine the weights to map from the old coordinates to the new coordinates.
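The NaN fix above works by zero-filling missing values before the weighted einsum and then masking every target cell that received any weight from a NaN source. Below is a minimal standalone numpy sketch of that idea; the toy data and weights are invented for illustration, while the real routine operates on xr.DataArray values and the weights produced by get_weights.

import numpy as np

# Toy data: 4 source rows regridded to 2 target rows along the first axis.
data = np.array([1.0, 2.0, np.nan, 4.0])
# Each column holds the weights mapping the 4 source rows onto one target row.
weights = np.array(
    [
        [0.5, 0.0],
        [0.5, 0.0],
        [0.0, 0.5],
        [0.0, 0.5],
    ]
)

# Weighted sum with NaNs treated as zero, mirroring the fillna(0) step.
new_data = np.einsum("i...,ij->j...", np.nan_to_num(data), weights)
# Any target cell that received nonzero weight from a NaN source is masked.
isnan = np.einsum("i...,ij->j...", np.isnan(data), weights)
new_data[isnan > 0] = np.nan

print(new_data)  # [1.5 nan]: the second target cell touched the NaN source.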
src/xarray_regrid/methods/interp.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+"""Methods based on xr.interp."""
+from typing import Literal, overload
+
+import xarray as xr
+
+
+@overload
+def interp_regrid(
+    data: xr.DataArray,
+    target_ds: xr.Dataset,
+    method: Literal["linear", "nearest", "cubic"],
+) -> xr.DataArray:
+    ...
+
+
+@overload
+def interp_regrid(
+    data: xr.Dataset,
+    target_ds: xr.Dataset,
+    method: Literal["linear", "nearest", "cubic"],
+) -> xr.Dataset:
+    ...
+
+
+def interp_regrid(
+    data: xr.DataArray | xr.Dataset,
+    target_ds: xr.Dataset,
+    method: Literal["linear", "nearest", "cubic"],
+) -> xr.DataArray | xr.Dataset:
+    """Refine a dataset using xarray's interp method.
+
+    Args:
+        data: Input dataset.
+        target_ds: Dataset which coordinates the input dataset should be regrid to.
+        method: Which interpolation method to use (e.g. 'linear', 'nearest').
+
+    Returns:
+        Regridded input dataset
+    """
+    coord_names = set(target_ds.coords).intersection(set(data.coords))
+    coords = {name: target_ds[name] for name in coord_names}
+
+    return data.interp(
+        coords=coords,
+        method=method,
+    )
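For orientation, here is a small usage sketch of the relocated interp_regrid function. The import path xarray_regrid.methods.interp is assumed from the import added in regrid.py below, and the coordinate names and sizes are invented:

import numpy as np
import xarray as xr

# Assumed import path, based on the `from xarray_regrid.methods import ...`
# line added in regrid.py; adjust if the module lives elsewhere.
from xarray_regrid.methods.interp import interp_regrid

# Coarse source data on a 10 x 20 latitude/longitude grid (arbitrary values).
source = xr.Dataset(
    {"tp": (("latitude", "longitude"), np.random.rand(10, 20))},
    coords={
        "latitude": np.linspace(-45, 45, 10),
        "longitude": np.linspace(0, 95, 20),
    },
)

# The target dataset only needs the coordinates to regrid onto.
target = xr.Dataset(
    coords={
        "latitude": np.linspace(-40, 40, 17),
        "longitude": np.linspace(0, 90, 37),
    }
)

regridded = interp_regrid(source, target, "linear")
print(regridded["tp"].sizes)  # latitude: 17, longitude: 37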

src/xarray_regrid/regrid.py

Lines changed: 7 additions & 5 deletions
@@ -1,6 +1,6 @@
 import xarray as xr
 
-from xarray_regrid import methods, most_common
+from xarray_regrid.methods import conservative, interp, most_common
 
 
 @xr.register_dataarray_accessor("regrid")
@@ -34,7 +34,7 @@ def linear(
             Data regridded to the target dataset coordinates.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
-        return methods.interp_regrid(self._obj, ds_target_grid, "linear")
+        return interp.interp_regrid(self._obj, ds_target_grid, "linear")
 
     def nearest(
         self,
@@ -51,7 +51,7 @@ def nearest(
             Data regridded to the target dataset coordinates.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
-        return methods.interp_regrid(self._obj, ds_target_grid, "nearest")
+        return interp.interp_regrid(self._obj, ds_target_grid, "nearest")
 
     def cubic(
         self,
@@ -68,7 +68,7 @@ def cubic(
         Returns:
             Data regridded to the target dataset coordinates.
         """
-        return methods.interp_regrid(self._obj, ds_target_grid, "cubic")
+        return interp.interp_regrid(self._obj, ds_target_grid, "cubic")
 
     def conservative(
         self,
@@ -89,7 +89,9 @@ def conservative(
         """
 
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
-        return methods.conservative_regrid(self._obj, ds_target_grid, latitude_coord)
+        return conservative.conservative_regrid(
+            self._obj, ds_target_grid, latitude_coord
+        )
 
     def most_common(
         self,
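The public accessor API is unchanged by this refactor; only the internal dispatch targets moved. A brief usage sketch follows, with hypothetical file names and assuming the time dimension default; the conservative call mirrors the new test below.

import xarray as xr
import xarray_regrid  # noqa: F401  (importing the package registers the .regrid accessor)

ds = xr.open_dataset("input.nc")       # hypothetical input data
target = xr.open_dataset("target.nc")  # hypothetical target grid

# Interpolation-based methods now dispatch to xarray_regrid.methods.interp.
ds_linear = ds.regrid.linear(target)
ds_nearest = ds.regrid.nearest(target)

# Conservative regridding dispatches to xarray_regrid.methods.conservative
# and, after this commit, propagates NaNs instead of treating them as zeros.
ds_cons = ds.regrid.conservative(target, latitude_coord="latitude")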

tests/test_regrid.py

Lines changed: 23 additions & 0 deletions
@@ -100,3 +100,26 @@ def test_conservative_regridder(conservative_input_data, conservative_sample_grid
         rtol=0.002,
         atol=2e-6,
     )
+
+
+def test_conservative_nans(conservative_input_data, conservative_sample_grid):
+    ds = conservative_input_data
+    ds["tp"] = ds["tp"].where(ds.latitude >= 0).where(ds.longitude < 180)
+    ds_regrid = ds.regrid.conservative(
+        conservative_sample_grid, latitude_coord="latitude"
+    )
+    ds_cdo = xr.open_dataset(CDO_DATA["conservative"])
+
+    # Cut off the edges: edge performance to be improved later (hopefully)
+    no_edges = {"latitude": slice(-85, 85), "longitude": slice(5, 355)}
+    no_nans = {"latitude": slice(1, 90), "longitude": slice(None, 179)}
+    xr.testing.assert_allclose(
+        ds_regrid["tp"]
+        .sel(no_edges)
+        .sel(no_nans)
+        .compute()
+        .transpose("time", "latitude", "longitude"),
+        ds_cdo["tp"].sel(no_edges).sel(no_nans).compute(),
+        rtol=0.002,
+        atol=2e-6,
+    )
