Skip to content

Commit 432b847

Browse files
Allow descending ordered coordinates in target grid (#29)
* Added sortby to target grid * Output grid matches input target grid order * Added coord order tests * Linting fixes * Make coord order tests more compact * Use reindex_like for aligning coords, make most_common test more compact --------- Co-authored-by: Bart Schilperoort <[email protected]>
1 parent 9a2db21 commit 432b847

File tree

4 files changed

+100
-13
lines changed

4 files changed

+100
-13
lines changed

src/xarray_regrid/methods/conservative.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,22 @@ def conservative_regrid(
5858
dim_order = list(target_ds.dims)
5959

6060
coord_names = set(target_ds.coords).intersection(set(data.coords))
61-
coords = {name: target_ds[name] for name in coord_names}
61+
target_ds_sorted = target_ds.sortby(list(coord_names))
62+
coords = {name: target_ds_sorted[name] for name in coord_names}
6263
data = data.sortby(list(coord_names))
6364

6465
if isinstance(data, xr.Dataset):
65-
return conservative_regrid_dataset(data, coords, latitude_coord).transpose(
66-
*dim_order, ...
67-
)
66+
regridded_data = conservative_regrid_dataset(
67+
data, coords, latitude_coord
68+
).transpose(*dim_order, ...)
6869
else:
69-
return conservative_regrid_dataarray(data, coords, latitude_coord).transpose(
70-
*dim_order, ...
71-
)
70+
regridded_data = conservative_regrid_dataarray( # type: ignore
71+
data, coords, latitude_coord
72+
).transpose(*dim_order, ...)
73+
74+
regridded_data = regridded_data.reindex_like(target_ds, copy=False)
75+
76+
return regridded_data
7277

7378

7479
def conservative_regrid_dataset(

src/xarray_regrid/methods/most_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,18 @@ def most_common_wrapper(
6161
data = data.to_dataset(name=da_name)
6262

6363
coords = utils.common_coords(data, target_ds)
64+
target_ds_sorted = target_ds.sortby(list(coords))
6465
coord_size = [data[coord].size for coord in coords]
6566
mem_usage = np.prod(coord_size) * np.zeros((1,), dtype=np.int64).itemsize
6667

6768
if max_mem is not None and mem_usage > max_mem:
6869
result = split_combine_most_common(
69-
data=data, target_ds=target_ds, time_dim=time_dim, max_mem=max_mem
70+
data=data, target_ds=target_ds_sorted, time_dim=time_dim, max_mem=max_mem
7071
)
7172
else:
72-
result = most_common(data=data, target_ds=target_ds, time_dim=time_dim)
73+
result = most_common(data=data, target_ds=target_ds_sorted, time_dim=time_dim)
74+
75+
result = result.reindex_like(target_ds, copy=False)
7376

7477
if da_name is not None:
7578
return result[da_name]

tests/test_most_common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33
import xarray as xr
4+
from numpy.testing import assert_array_equal
45

56
from xarray_regrid import Grid, create_regridding_dataset
67

@@ -95,3 +96,21 @@ def test_attrs_dataset(dummy_lc_data, dummy_target_grid):
9596
assert ds_regrid.attrs != {}
9697
assert ds_regrid.attrs == dummy_lc_data.attrs
9798
assert ds_regrid["longitude"].attrs == dummy_lc_data["longitude"].attrs
99+
100+
101+
@pytest.mark.parametrize("dataarray", [True, False])
102+
def test_coord_order_original(dummy_lc_data, dummy_target_grid, dataarray):
103+
input_data = dummy_lc_data["lc"] if dataarray else dummy_lc_data
104+
ds_regrid = input_data.regrid.most_common(dummy_target_grid)
105+
assert_array_equal(ds_regrid["latitude"], dummy_target_grid["latitude"])
106+
assert_array_equal(ds_regrid["longitude"], dummy_target_grid["longitude"])
107+
108+
109+
@pytest.mark.parametrize("coord", ["latitude", "longitude"])
110+
@pytest.mark.parametrize("dataarray", [True, False])
111+
def test_coord_order_reversed(dummy_lc_data, dummy_target_grid, coord, dataarray):
112+
input_data = dummy_lc_data["lc"] if dataarray else dummy_lc_data
113+
dummy_target_grid[coord] = list(reversed(dummy_target_grid[coord]))
114+
ds_regrid = input_data.regrid.most_common(dummy_target_grid)
115+
assert_array_equal(ds_regrid["latitude"], dummy_target_grid["latitude"])
116+
assert_array_equal(ds_regrid["longitude"], dummy_target_grid["longitude"])

tests/test_regrid.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from copy import deepcopy
12
from pathlib import Path
23

34
import pytest
45
import xarray as xr
6+
from numpy.testing import assert_array_equal
57

68
import xarray_regrid
79

@@ -14,9 +16,15 @@
1416
}
1517

1618

19+
@pytest.fixture(scope="session")
20+
def load_input_data() -> xr.Dataset:
21+
ds = xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc")
22+
return ds.compute()
23+
24+
1725
@pytest.fixture
18-
def sample_input_data() -> xr.Dataset:
19-
return xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc")
26+
def sample_input_data(load_input_data) -> xr.Dataset:
27+
return deepcopy(load_input_data)
2028

2129

2230
@pytest.fixture
@@ -63,9 +71,15 @@ def test_basic_regridders_da(sample_input_data, sample_grid_ds, method, cdo_file
6371
xr.testing.assert_allclose(da_regrid.compute(), ds_cdo["d2m"].compute())
6472

6573

74+
@pytest.fixture(scope="session")
75+
def load_conservative_input_data() -> xr.Dataset:
76+
ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc")
77+
return ds.compute()
78+
79+
6680
@pytest.fixture
67-
def conservative_input_data() -> xr.Dataset:
68-
return xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc")
81+
def conservative_input_data(load_conservative_input_data) -> xr.Dataset:
82+
return deepcopy(load_conservative_input_data)
6983

7084

7185
@pytest.fixture
@@ -156,3 +170,49 @@ def test_attrs_dataset_conservative(sample_input_data, sample_grid_ds):
156170
assert ds_regrid.attrs == sample_input_data.attrs
157171
assert ds_regrid["d2m"].attrs == sample_input_data["d2m"].attrs
158172
assert ds_regrid["longitude"].attrs == sample_input_data["longitude"].attrs
173+
174+
175+
class TestCoordOrder:
176+
@pytest.mark.parametrize("method", ["linear", "nearest", "cubic"])
177+
@pytest.mark.parametrize("dataarray", [True, False])
178+
def test_original(self, sample_input_data, sample_grid_ds, method, dataarray):
179+
input_data = sample_input_data["d2m"] if dataarray else sample_input_data
180+
regridder = getattr(input_data.regrid, method)
181+
ds_regrid = regridder(sample_grid_ds)
182+
assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"])
183+
assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"])
184+
185+
@pytest.mark.parametrize("coord", ["latitude", "longitude"])
186+
@pytest.mark.parametrize("method", ["linear", "nearest", "cubic"])
187+
@pytest.mark.parametrize("dataarray", [True, False])
188+
def test_reversed(
189+
self, sample_input_data, sample_grid_ds, method, coord, dataarray
190+
):
191+
input_data = sample_input_data["d2m"] if dataarray else sample_input_data
192+
regridder = getattr(input_data.regrid, method)
193+
sample_grid_ds[coord] = list(reversed(sample_grid_ds[coord]))
194+
ds_regrid = regridder(sample_grid_ds)
195+
assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"])
196+
assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"])
197+
198+
@pytest.mark.parametrize("dataarray", [True, False])
199+
def test_conservative_original(self, sample_input_data, sample_grid_ds, dataarray):
200+
input_data = sample_input_data["d2m"] if dataarray else sample_input_data
201+
ds_regrid = input_data.regrid.conservative(
202+
sample_grid_ds, latitude_coord="latitude"
203+
)
204+
assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"])
205+
assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"])
206+
207+
@pytest.mark.parametrize("coord", ["latitude", "longitude"])
208+
@pytest.mark.parametrize("dataarray", [True, False])
209+
def test_conservative_reversed(
210+
self, sample_input_data, sample_grid_ds, coord, dataarray
211+
):
212+
input_data = sample_input_data["d2m"] if dataarray else sample_input_data
213+
sample_grid_ds[coord] = list(reversed(sample_grid_ds[coord]))
214+
ds_regrid = input_data.regrid.conservative(
215+
sample_grid_ds, latitude_coord="latitude"
216+
)
217+
assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"])
218+
assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"])

0 commit comments

Comments
 (0)