Skip to content

Commit 05223b3

Browse files
committed
Add simple test for most_common. Fix typing issues
1 parent 51ee64c commit 05223b3

File tree

4 files changed

+68
-17
lines changed

4 files changed

+68
-17
lines changed

src/xarray_regrid/most_common.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Implementation of the "most common value" regridding method."""
22

33
from itertools import product
4-
from typing import overload
4+
from typing import Any, overload
55

66
import flox.xarray
77
import numpy as np
8-
import numpy_groupies as npg
8+
import numpy_groupies as npg # type: ignore
99
import pandas as pd
1010
import xarray as xr
1111
from flox import Aggregation
@@ -78,7 +78,7 @@ def most_common_wrapper(
7878

7979

8080
def split_combine_most_common(
81-
data: xr.Dataset, target_ds: xr.Dataset, time_dim: str, max_mem: int = 1e9
81+
data: xr.Dataset, target_ds: xr.Dataset, time_dim: str, max_mem: int = int(1e9)
8282
) -> xr.Dataset:
8383
"""Use a split-combine strategy to reduce the memory use of the most_common regrid.
8484
@@ -173,7 +173,7 @@ def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Da
173173

174174
most_common = Aggregation(
175175
name="most_common",
176-
numpy=_custom_grouped_reduction,
176+
numpy=_custom_grouped_reduction, # type: ignore
177177
chunk=None,
178178
combine=None,
179179
)
@@ -208,14 +208,7 @@ def _most_common_label(neighbors: np.ndarray) -> np.ndarray:
208208
then the first label in the list will be picked.
209209
"""
210210
unique_labels, counts = np.unique(neighbors, return_counts=True)
211-
return unique_labels[np.argmax(counts)]
212-
213-
214-
def most_common_chunked(multi_values: np.ndarray, multi_counts: np.ndarray):
215-
all_values, index = np.unique(multi_values, return_inverse=True)
216-
all_counts = np.zeros(all_values.size, np.int64)
217-
np.add.at(all_counts, index, multi_counts.ravel()) # inplace
218-
return all_values[all_counts.argmax()]
211+
return unique_labels[np.argmax(counts)] # type: ignore
219212

220213

221214
def _custom_grouped_reduction(
@@ -224,8 +217,8 @@ def _custom_grouped_reduction(
224217
*,
225218
axis: int = -1,
226219
size: int | None = None,
227-
fill_value=None,
228-
dtype=None,
220+
fill_value: Any = None,
221+
dtype: Any = None,
229222
) -> np.ndarray:
230223
"""Custom grouped reduction for flox.Aggregation to get most common label.
231224
@@ -242,7 +235,7 @@ def _custom_grouped_reduction(
242235
Returns:
243236
np.ndarray with array.shape[-1] == size, containing a single value per group
244237
"""
245-
return npg.aggregate_numpy.aggregate(
238+
agg: np.ndarray = npg.aggregate_numpy.aggregate(
246239
group_idx,
247240
array,
248241
func=_most_common_label,
@@ -251,3 +244,4 @@ def _custom_grouped_reduction(
251244
fill_value=fill_value,
252245
dtype=dtype,
253246
)
247+
return agg

src/xarray_regrid/regrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def most_common(
9595
self,
9696
ds_target_grid: xr.Dataset,
9797
time_dim: str = "time",
98-
max_mem: int = 1e9
98+
max_mem: int = int(1e9),
9999
) -> xr.DataArray | xr.Dataset:
100100
"""Regrid by taking the most common value within the new grid cells.
101101

src/xarray_regrid/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Hashable
12
from dataclasses import dataclass
23

34
import numpy as np
@@ -156,7 +157,7 @@ def common_coords(
156157
data1: xr.DataArray | xr.Dataset,
157158
data2: xr.DataArray | xr.Dataset,
158159
remove_coord: str | None = None,
159-
) -> set[str]:
160+
) -> set[Hashable]:
160161
"""Return a set of coords which two dataset/arrays have in common."""
161162
coords = set(data1.coords).intersection(set(data2.coords))
162163
if remove_coord in coords:

tests/test_most_common.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
2+
import pytest
3+
import xarray as xr
4+
5+
from xarray_regrid import Grid, create_regridding_dataset
6+
7+
8+
@pytest.fixture
def dummy_lc_data():
    """Build a small 11x11 land-cover-like dataset with integer labels 0-2.

    Uses a fixed RNG seed so the class labels (and therefore the expected
    regridding result in the test below) are reproducible.
    """
    np.random.seed(0)
    lc_values = np.random.randint(0, 3, size=(11, 11))
    lat = np.linspace(0, 40, num=11)
    lon = np.linspace(0, 40, num=11)
    return xr.Dataset(
        data_vars={"lc": (["latitude", "longitude"], lc_values)},
        coords={
            "longitude": (["longitude"], lon),
            "latitude": (["latitude"], lat),
        },
    )
24+
25+
26+
@pytest.fixture
def dummy_target_grid():
    """Target grid dataset: 8-degree cells over the same 0-40 degree extent."""
    target = Grid(
        north=40,
        east=40,
        south=0,
        west=0,
        resolution_lat=8,
        resolution_lon=8,
    )
    return create_regridding_dataset(target)
37+
38+
39+
def test_most_common(dummy_lc_data, dummy_target_grid):
    """Regridding by most_common should pick the modal label per target cell."""
    # Expected modal labels for the seeded 11x11 input on the 6x6 target grid.
    expected = np.array(
        [
            [0, 0, 1, 0, 0, 0],
            [0, 1, 0, 1, 1, 0],
            [0, 1, 2, 0, 0, 2],
            [1, 0, 1, 2, 0, 0],
            [0, 0, 0, 1, 0, 1],
            [1, 2, 2, 0, 2, 2],
        ]
    )
    result = dummy_lc_data.regrid.most_common(dummy_target_grid)["lc"]
    np.testing.assert_array_equal(result.values, expected)

0 commit comments

Comments
 (0)