Skip to content

Commit 4107a3a

Browse files
committed
Fix most_common test, improve docstring.
1 parent 05223b3 commit 4107a3a

File tree

4 files changed

+48
-20
lines changed

4 files changed

+48
-20
lines changed

src/xarray_regrid/most_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Da
153153
Returns:
154154
xarray.dataset with regridded land cover categorical data.
155155
"""
156+
dim_order = data.dims
156157
coords = utils.common_coords(data, target_ds, remove_coord=time_dim)
157158
bounds = tuple(
158159
_construct_intervals(target_ds[coord].to_numpy()) for coord in coords
@@ -189,7 +190,7 @@ def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Da
189190
for coord in coords:
190191
ds_regrid[coord] = target_ds[coord]
191192

192-
return ds_regrid
193+
return ds_regrid.transpose(*dim_order)
193194

194195

195196
def _construct_intervals(coord: np.ndarray) -> pd.IntervalIndex:

src/xarray_regrid/regrid.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,12 @@ def most_common(
9999
) -> xr.DataArray | xr.Dataset:
100100
"""Regrid by taking the most common value within the new grid cells.
101101
102-
To be used for regridding data to a much coarser resolution.
102+
To be used for regridding data to a much coarser resolution, not for regridding
103+
when the source and target grids are of a similar resolution.
104+
105+
Note that in the case of two unique values with the same count, the behaviour
106+
is not deterministic, and the resulting "most common" one will randomly be
107+
either of the two.
103108
104109
Args:
105110
ds_target_grid: Target grid dataset

src/xarray_regrid/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from collections.abc import Hashable
21
from dataclasses import dataclass
32

43
import numpy as np
@@ -157,9 +156,9 @@ def common_coords(
157156
data1: xr.DataArray | xr.Dataset,
158157
data2: xr.DataArray | xr.Dataset,
159158
remove_coord: str | None = None,
160-
) -> set[Hashable]:
159+
) -> list[str]:
161160
"""Return a sorted list of the coord names which two datasets/arrays have in common."""
162161
coords = set(data1.coords).intersection(set(data2.coords))
163162
if remove_coord in coords:
164163
coords.remove(remove_coord)
165-
return coords
164+
return sorted([str(coord) for coord in coords])

tests/test_most_common.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,27 @@
77

88
@pytest.fixture
99
def dummy_lc_data():
10-
np.random.seed(0)
11-
data = np.random.randint(0, 3, size=(11, 11))
10+
data = np.array(
11+
[
12+
[2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0],
13+
[2, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0],
14+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
15+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
16+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
17+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
18+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
19+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
20+
[3, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
21+
[3, 3, 3, 3, 0, 0, 0, 0, 1, 1, 1],
22+
[3, 3, 0, 3, 0, 0, 0, 0, 1, 1, 1],
23+
]
24+
)
1225
lat_coords = np.linspace(0, 40, num=11)
1326
lon_coords = np.linspace(0, 40, num=11)
1427

1528
return xr.Dataset(
1629
data_vars={
17-
"lc": (["latitude", "longitude"], data),
30+
"lc": (["longitude", "latitude"], data),
1831
},
1932
coords={
2033
"longitude": (["longitude"], lon_coords),
@@ -37,20 +50,30 @@ def dummy_target_grid():
3750

3851

3952
def test_most_common(dummy_lc_data, dummy_target_grid):
40-
expected = np.array(
53+
expected_data = np.array(
4154
[
42-
[0, 0, 1, 0, 0, 0],
43-
[0, 1, 0, 1, 1, 0],
44-
[0, 1, 2, 0, 0, 2],
45-
[1, 0, 1, 2, 0, 0],
46-
[0, 0, 0, 1, 0, 1],
47-
[1, 2, 2, 0, 2, 2],
55+
[2, 2, 0, 0, 0, 0],
56+
[0, 0, 0, 0, 0, 0],
57+
[0, 0, 0, 0, 0, 0],
58+
[0, 0, 0, 0, 0, 0],
59+
[0, 0, 0, 0, 0, 0],
60+
[3, 3, 0, 0, 0, 1],
4861
]
4962
)
5063

51-
np.testing.assert_array_equal(
52-
dummy_lc_data.regrid.most_common(
53-
dummy_target_grid,
54-
)["lc"].values,
55-
expected,
64+
lat_coords = np.linspace(0, 40, num=6)
65+
lon_coords = np.linspace(0, 40, num=6)
66+
67+
expected = xr.Dataset(
68+
data_vars={
69+
"lc": (["longitude", "latitude"], expected_data),
70+
},
71+
coords={
72+
"longitude": (["longitude"], lon_coords),
73+
"latitude": (["latitude"], lat_coords),
74+
},
75+
)
76+
xr.testing.assert_equal(
77+
dummy_lc_data.regrid.most_common(dummy_target_grid)["lc"],
78+
expected["lc"],
5679
)

0 commit comments

Comments
 (0)