Skip to content

Commit bc7be5b

Browse files
authored
Sparse weights in conservative method (#49)
* add weights sparsity * add opt-einsum * fix linting * make sparse optional * shuffle dependency groups * add readme note, test other methods with varying chunks * update changelog
1 parent 9d71962 commit bc7be5b

File tree

6 files changed

+85
-19
lines changed

6 files changed

+85
-19
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
66

77
## Unreleased
88

9-
# 0.3.0 (2024-09-05)
9+
Added:
10+
- If latitude/longitude coordinates are detected and the domain is global, apply automatic padding at the boundaries, which gives behavior more consistent with common tools like ESMF and CDO ([#45](https://github.com/xarray-contrib/xarray-regrid/pull/45)).
11+
- Conservative regridding weights are converted to sparse matrices if the optional [sparse](https://github.com/pydata/sparse) package is installed, which improves compute and memory performance in most cases ([#49](https://github.com/xarray-contrib/xarray-regrid/pull/49)).
12+
13+
14+
## 0.3.0 (2024-09-05)
1015

1116
New contributors:
1217
- [@slevang](https://github.com/slevang)

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,23 @@ Regridding is a common operation in earth science and other fields. While xarray
2424

2525
## Installation
2626

27+
For a minimal install:
2728
```console
2829
pip install xarray-regrid
2930
```
3031

32+
To improve performance in certain cases:
33+
```console
34+
pip install xarray-regrid[accel]
35+
```
36+
37+
which includes optional extras such as:
38+
- `dask`: parallelization over chunked data
39+
- `sparse`: for performing conservative regridding using sparse weight matrices
40+
- `opt-einsum`: optimized einsum routines used in conservative regridding
41+
42+
Benchmarking varies across different hardware specifications, but the inclusion of these extras can often provide significant speedups.
43+
3144
## Usage
3245
The xarray-regrid routines are accessed using the "regrid" accessor on an xarray Dataset:
3346
```py

environment.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,4 @@ dependencies:
77
- xESMF
88
- cdo
99
- pip:
10-
- "xarray-regrid"
11-
- "cftime"
12-
- "matplotlib"
13-
- "dask[distributed]"
10+
- "xarray-regrid[accel,benchmarking,dev]"

pyproject.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,27 @@ Issues = "https://github.com/EXCITED-CO2/xarray-regrid/issues"
3939
Source = "https://github.com/EXCITED-CO2/xarray-regrid"
4040

4141
[project.optional-dependencies]
42-
benchmarking = [
42+
accel = [
43+
"sparse",
44+
"opt-einsum",
4345
"dask[distributed]",
46+
]
47+
benchmarking = [
4448
"matplotlib",
4549
"zarr",
50+
"h5netcdf",
4651
"requests",
4752
"aiohttp",
53+
"pooch",
54+
"cftime", # required for decode time of test netCDF files
4855
]
4956
dev = [
5057
"hatch",
5158
"ruff",
5259
"mypy",
5360
"pytest",
5461
"pytest-cov",
55-
"cftime", # required for decode time of test netCDF files
5662
"pandas-stubs", # Adds typing for pandas.
57-
"cftime",
5863
]
5964
docs = [ # Required for ReadTheDocs
6065
"myst_parser",
@@ -69,7 +74,7 @@ docs = [ # Required for ReadTheDocs
6974
path = "src/xarray_regrid/__init__.py"
7075

7176
[tool.hatch.envs.default]
72-
features = ["dev", "benchmarking"]
77+
features = ["accel", "dev", "benchmarking"]
7378

7479
[tool.hatch.envs.default.scripts]
7580
lint = [

src/xarray_regrid/methods/conservative.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
import numpy as np
77
import xarray as xr
88

9+
try:
10+
import sparse # type: ignore
11+
except ImportError:
12+
sparse = None
13+
914
from xarray_regrid import utils
1015

1116
EMPTY_DA_NAME = "FRAC_EMPTY"
@@ -125,9 +130,14 @@ def conservative_regrid_dataset(
125130

126131
for array in data_vars.keys():
127132
if coord in data_vars[array].dims:
133+
if sparse is not None:
134+
var_weights = sparsify_weights(weights, data_vars[array])
135+
else:
136+
var_weights = weights
137+
128138
data_vars[array], valid_fracs[array] = apply_weights(
129139
da=data_vars[array],
130-
weights=weights,
140+
weights=var_weights,
131141
coord=coord,
132142
valid_frac=valid_fracs[array],
133143
skipna=skipna,
@@ -171,7 +181,9 @@ def apply_weights(
171181
# Renormalize the weights along this dim by the accumulated valid_frac
172182
# along previous dimensions
173183
if valid_frac.name != EMPTY_DA_NAME:
174-
weights_norm = weights * valid_frac / valid_frac.mean(dim=[coord])
184+
weights_norm = weights * (valid_frac / valid_frac.mean(dim=[coord])).fillna(
185+
0
186+
)
175187

176188
da_reduced: xr.DataArray = xr.dot(
177189
da.fillna(0), weights_norm, dim=[coord], optimize=True
@@ -180,7 +192,7 @@ def apply_weights(
180192

181193
if skipna:
182194
weights_valid_sum: xr.DataArray = xr.dot(
183-
weights_norm, notnull, dim=[coord], optimize=True
195+
notnull, weights_norm, dim=[coord], optimize=True
184196
)
185197
weights_valid_sum = weights_valid_sum.rename(coord_map)
186198
da_reduced /= weights_valid_sum.clip(1e-6, None)
@@ -195,6 +207,17 @@ def apply_weights(
195207
valid_frac = valid_frac.rename(coord_map)
196208
valid_frac = valid_frac.clip(0, 1)
197209

210+
# In some cases, dot product of dask data and sparse weights fails
211+
# to automatically densify, which prevents future conversion to numpy
212+
if (
213+
sparse is not None
214+
and da_reduced.chunks
215+
and isinstance(da_reduced.data._meta, sparse.COO)
216+
):
217+
da_reduced.data = da_reduced.data.map_blocks(
218+
lambda x: x.todense(), dtype=da_reduced.dtype
219+
)
220+
198221
return da_reduced, valid_frac
199222

200223

@@ -248,3 +271,17 @@ def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray:
248271
lat = np.radians(latitude)
249272
h = np.sin(lat + dlat / 2) - np.sin(lat - dlat / 2)
250273
return h * dlat / (np.pi * 4) # type: ignore
274+
275+
276+
def sparsify_weights(weights: xr.DataArray, da: xr.DataArray) -> xr.DataArray:
    """Return a sparse copy of the regridding weights.

    The copy is cast to the dtype of ``da`` (the array about to be
    regridded) and, when ``da`` is chunked, re-chunked so the weight
    chunks line up with the data chunks along shared dimensions. The
    weights are dense when constructed, but the contraction step is more
    efficient against sparse (COO) weights.
    """
    sparse_weights = weights.copy().astype(da.dtype)
    if da.chunks:
        # Only dimensions the weights share with the data can be aligned.
        matching_chunks = {
            dim: size for dim, size in da.chunksizes.items() if dim in weights.dims
        }
        rechunked = sparse_weights.chunk(matching_chunks)
        sparse_weights.data = rechunked.data.map_blocks(sparse.COO)
    else:
        # Non-chunked path: convert the whole weight array in one shot.
        sparse_weights.data = sparse.COO(weights.data)
    return sparse_weights

tests/test_regrid.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"conservative": DATA_PATH / "cdo_conservative_64b.nc",
2121
}
2222

23+
CHUNK_SCHEMES = [{}, {"time": 1}, {"longitude": 100, "latitude": 100}]
24+
2325

2426
@pytest.fixture(scope="session")
2527
def sample_input_data() -> xr.Dataset:
@@ -71,31 +73,38 @@ def conservative_sample_grid():
7173

7274

7375
@pytest.mark.parametrize("method", ["linear", "nearest"])
@pytest.mark.parametrize("chunks", CHUNK_SCHEMES)
def test_basic_regridders_ds(
    sample_input_data, sample_grid_ds, cdo_comparison_data, method, chunks
):
    """Test the dataset regridders (except conservative).

    Runs each method over every chunking scheme and compares the result
    against reference output produced by CDO.
    """
    chunked_input = sample_input_data.chunk(chunks)
    regridder = getattr(chunked_input.regrid, method)
    ds_regrid = regridder(sample_grid_ds)
    ds_cdo = cdo_comparison_data[method]
    xr.testing.assert_allclose(ds_regrid, ds_cdo, rtol=0.002, atol=2e-5)
8285

8386

8487
@pytest.mark.parametrize("method", ["linear", "nearest"])
@pytest.mark.parametrize("chunks", CHUNK_SCHEMES)
def test_basic_regridders_da(
    sample_input_data, sample_grid_ds, cdo_comparison_data, method, chunks
):
    """Test the dataarray regridders (except conservative).

    Same as the dataset test but exercises the accessor on a single
    DataArray ("d2m"), across every chunking scheme, against CDO output.
    """
    chunked_var = sample_input_data["d2m"].chunk(chunks)
    regridder = getattr(chunked_var.regrid, method)
    da_regrid = regridder(sample_grid_ds)
    da_cdo = cdo_comparison_data[method]["d2m"]
    xr.testing.assert_allclose(da_regrid, da_cdo, rtol=0.002, atol=2e-5)
9397

9498

99+
@pytest.mark.parametrize("chunks", CHUNK_SCHEMES)
95100
def test_conservative_regridder(
96-
conservative_input_data, conservative_sample_grid, cdo_comparison_data
101+
conservative_input_data,
102+
conservative_sample_grid,
103+
cdo_comparison_data,
104+
chunks,
97105
):
98-
ds_regrid = conservative_input_data.regrid.conservative(
106+
input_data = conservative_input_data.chunk(chunks)
107+
ds_regrid = input_data.regrid.conservative(
99108
conservative_sample_grid, latitude_coord="latitude"
100109
)
101110
ds_cdo = cdo_comparison_data["conservative"]
@@ -201,7 +210,7 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold):
201210

202211
@pytest.mark.skipif(xesmf is None, reason="xesmf required")
203212
def test_conservative_nan_thresholds_against_xesmf():
204-
ds = xr.tutorial.open_dataset("ersstv5").sst.compute().isel(time=[0])
213+
ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).compute()
205214
ds = ds.rename(lon="longitude", lat="latitude")
206215
new_grid = xarray_regrid.Grid(
207216
north=90,

0 commit comments

Comments
 (0)