|
| 1 | +from copy import deepcopy |
1 | 2 | from pathlib import Path |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 | import xarray as xr |
| 6 | +from numpy.testing import assert_array_equal |
5 | 7 |
|
6 | 8 | import xarray_regrid |
7 | 9 |
|
|
14 | 16 | } |
15 | 17 |
|
16 | 18 |
|
| 19 | +@pytest.fixture(scope="session") |
| 20 | +def load_input_data() -> xr.Dataset: |
| 21 | + ds = xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc") |
| 22 | + return ds.compute() |
| 23 | + |
| 24 | + |
17 | 25 | @pytest.fixture |
18 | | -def sample_input_data() -> xr.Dataset: |
19 | | - return xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc") |
| 26 | +def sample_input_data(load_input_data) -> xr.Dataset: |
| 27 | + return deepcopy(load_input_data) |
20 | 28 |
|
21 | 29 |
|
22 | 30 | @pytest.fixture |
@@ -63,9 +71,15 @@ def test_basic_regridders_da(sample_input_data, sample_grid_ds, method, cdo_file |
63 | 71 | xr.testing.assert_allclose(da_regrid.compute(), ds_cdo["d2m"].compute()) |
64 | 72 |
|
65 | 73 |
|
| 74 | +@pytest.fixture(scope="session") |
| 75 | +def load_conservative_input_data() -> xr.Dataset: |
| 76 | + ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") |
| 77 | + return ds.compute() |
| 78 | + |
| 79 | + |
66 | 80 | @pytest.fixture |
67 | | -def conservative_input_data() -> xr.Dataset: |
68 | | - return xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") |
| 81 | +def conservative_input_data(load_conservative_input_data) -> xr.Dataset: |
| 82 | + return deepcopy(load_conservative_input_data) |
69 | 83 |
|
70 | 84 |
|
71 | 85 | @pytest.fixture |
@@ -156,3 +170,49 @@ def test_attrs_dataset_conservative(sample_input_data, sample_grid_ds): |
156 | 170 | assert ds_regrid.attrs == sample_input_data.attrs |
157 | 171 | assert ds_regrid["d2m"].attrs == sample_input_data["d2m"].attrs |
158 | 172 | assert ds_regrid["longitude"].attrs == sample_input_data["longitude"].attrs |
| 173 | + |
| 174 | + |
| 175 | +class TestCoordOrder: |
| 176 | + @pytest.mark.parametrize("method", ["linear", "nearest", "cubic"]) |
| 177 | + @pytest.mark.parametrize("dataarray", [True, False]) |
| 178 | + def test_original(self, sample_input_data, sample_grid_ds, method, dataarray): |
| 179 | + input_data = sample_input_data["d2m"] if dataarray else sample_input_data |
| 180 | + regridder = getattr(input_data.regrid, method) |
| 181 | + ds_regrid = regridder(sample_grid_ds) |
| 182 | + assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"]) |
| 183 | + assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"]) |
| 184 | + |
| 185 | + @pytest.mark.parametrize("coord", ["latitude", "longitude"]) |
| 186 | + @pytest.mark.parametrize("method", ["linear", "nearest", "cubic"]) |
| 187 | + @pytest.mark.parametrize("dataarray", [True, False]) |
| 188 | + def test_reversed( |
| 189 | + self, sample_input_data, sample_grid_ds, method, coord, dataarray |
| 190 | + ): |
| 191 | + input_data = sample_input_data["d2m"] if dataarray else sample_input_data |
| 192 | + regridder = getattr(input_data.regrid, method) |
| 193 | + sample_grid_ds[coord] = list(reversed(sample_grid_ds[coord])) |
| 194 | + ds_regrid = regridder(sample_grid_ds) |
| 195 | + assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"]) |
| 196 | + assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"]) |
| 197 | + |
| 198 | + @pytest.mark.parametrize("dataarray", [True, False]) |
| 199 | + def test_conservative_original(self, sample_input_data, sample_grid_ds, dataarray): |
| 200 | + input_data = sample_input_data["d2m"] if dataarray else sample_input_data |
| 201 | + ds_regrid = input_data.regrid.conservative( |
| 202 | + sample_grid_ds, latitude_coord="latitude" |
| 203 | + ) |
| 204 | + assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"]) |
| 205 | + assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"]) |
| 206 | + |
| 207 | + @pytest.mark.parametrize("coord", ["latitude", "longitude"]) |
| 208 | + @pytest.mark.parametrize("dataarray", [True, False]) |
| 209 | + def test_conservative_reversed( |
| 210 | + self, sample_input_data, sample_grid_ds, coord, dataarray |
| 211 | + ): |
| 212 | + input_data = sample_input_data["d2m"] if dataarray else sample_input_data |
| 213 | + sample_grid_ds[coord] = list(reversed(sample_grid_ds[coord])) |
| 214 | + ds_regrid = input_data.regrid.conservative( |
| 215 | + sample_grid_ds, latitude_coord="latitude" |
| 216 | + ) |
| 217 | + assert_array_equal(ds_regrid["latitude"], sample_grid_ds["latitude"]) |
| 218 | + assert_array_equal(ds_regrid["longitude"], sample_grid_ds["longitude"]) |
0 commit comments