Skip to content

Commit f2120b8

Browse files
Ensure most_common returns correct dtype (#53)
* Add test that should pass * Only call .where(covered) if needed. Improve var naming * Expose fill_value to regrid.most_common/least_common * Expose fill_value in regrid.stat as well * Ruff changed formatting * Warn the user if data is cast to float * Undo formatting change
1 parent 0bd64aa commit f2120b8

File tree

4 files changed

+48
-16
lines changed

4 files changed

+48
-16
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ ignore = [
135135
"S105", "S106", "S107",
136136
# Ignore complexity
137137
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
138+
# Ignore magic values (false positives)
139+
"PLR2004",
138140
# Causes conflicts with formatter
139141
"ISC001",
140142
]

src/xarray_regrid/methods/_shared.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Utility functions shared between methods."""
22

3+
import warnings
34
from collections.abc import Hashable
45
from typing import Any, overload
56

@@ -53,13 +54,21 @@ def restore_properties(
5354
result[coord].attrs = target_ds[coord].attrs
5455

5556
# Replace zeros outside of original data grid with NaNs
56-
uncovered_target_grid = (target_ds[coord] <= original_data[coord].max()) & (
57+
covered = (target_ds[coord] <= original_data[coord].max()) & (
5758
target_ds[coord] >= original_data[coord].min()
5859
)
59-
if fill_value is None:
60-
result = result.where(uncovered_target_grid)
61-
else:
62-
result = result.where(uncovered_target_grid, fill_value)
60+
61+
if (~covered).any():
62+
if fill_value is None:
63+
if np.issubdtype(result.dtype, np.integer):
64+
msg = (
65+
"No fill_value is provided; data will be cast to "
66+
"floating point dtype to be able to use NaN for missing values."
67+
)
68+
warnings.warn(msg, stacklevel=1)
69+
result = result.where(covered)
70+
else:
71+
result = result.where(covered, fill_value)
6372

6473
return result.transpose(*original_data.dims)
6574

src/xarray_regrid/regrid.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Hashable
2-
from typing import overload
2+
from typing import Any, overload
33

44
import numpy as np
55
import xarray as xr
@@ -14,11 +14,12 @@ class Regridder:
1414
"""Regridding xarray datasets and dataarrays.
1515
1616
Available methods:
17-
linear: linear, bilinear, or higher dimensional linear interpolation.
18-
nearest: nearest-neighbor regridding.
19-
cubic: cubic spline regridding.
20-
conservative: conservative regridding.
17+
linear: linear, bilinear, or higher dimensional linear interpolation
18+
nearest: nearest-neighbor regridding
19+
cubic: cubic spline regridding
20+
conservative: conservative regridding
2121
most_common: most common value regridder
22+
stat: area statistics regridder
2223
"""
2324

2425
def __init__(self, xarray_obj: xr.DataArray | xr.Dataset):
@@ -134,6 +135,7 @@ def most_common(
134135
ds_target_grid: xr.Dataset,
135136
values: np.ndarray,
136137
time_dim: str | None = "time",
138+
fill_value: None | Any = None,
137139
) -> xr.DataArray:
138140
"""Regrid by taking the most common value within the new grid cells.
139141
@@ -151,6 +153,9 @@ def most_common(
151153
contains the values 0, 2 and 4.
152154
time_dim: Name of the time dimension. Defaults to "time". Use `None` to
153155
force regridding over the time dimension.
156+
fill_value: What value to fill uncovered parts of the target grid.
157+
By default this will be NaN, and integer type data will be cast to
158+
float to accomodate this.
154159
155160
Returns:
156161
Regridded data.
@@ -173,6 +178,7 @@ def most_common(
173178
ds_target_grid,
174179
values,
175180
time_dim,
181+
fill_value,
176182
anti_mode=False,
177183
)
178184

@@ -181,6 +187,7 @@ def least_common(
181187
ds_target_grid: xr.Dataset,
182188
values: np.ndarray,
183189
time_dim: str | None = "time",
190+
fill_value: None | Any = None,
184191
) -> xr.DataArray:
185192
"""Regrid by taking the least common value within the new grid cells.
186193
@@ -198,6 +205,9 @@ def least_common(
198205
contains the values 0, 2 and 4.
199206
time_dim: Name of the time dimension. Defaults to "time". Use `None` to
200207
force regridding over the time dimension.
208+
fill_value: What value to fill uncovered parts of the target grid.
209+
By default this will be NaN, and integer type data will be cast to
210+
float to accomodate this.
201211
202212
Returns:
203213
Regridded data.
@@ -220,6 +230,7 @@ def least_common(
220230
ds_target_grid,
221231
values,
222232
time_dim,
233+
fill_value,
223234
anti_mode=True,
224235
)
225236

@@ -229,6 +240,7 @@ def stat(
229240
method: str,
230241
time_dim: str | None = "time",
231242
skipna: bool = False,
243+
fill_value: None | Any = None,
232244
) -> xr.DataArray | xr.Dataset:
233245
"""Upsampling of data using statistical methods (e.g. the mean or variance).
234246
@@ -243,6 +255,9 @@ def stat(
243255
time_dim: Name of the time dimension. Defaults to "time". Use `None` to
244256
force regridding over the time dimension.
245257
skipna: If NaN values should be ignored.
258+
fill_value: What value to fill uncovered parts of the target grid.
259+
By default this will be NaN, and integer type data will be cast to
260+
float to accomodate this.
246261
247262
Returns:
248263
xarray.dataset with regridded land cover categorical data.
@@ -251,7 +266,7 @@ def stat(
251266
ds_formatted = format_for_regrid(self._obj, ds_target_grid, stats=True)
252267

253268
return flox_reduce.statistic_reduce(
254-
ds_formatted, ds_target_grid, time_dim, method, skipna
269+
ds_formatted, ds_target_grid, time_dim, method, skipna, fill_value
255270
)
256271

257272

tests/test_reduce.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,22 @@ def test_most_common(dummy_lc_data, dummy_target_grid):
9393
[0, 0, 0, 0, 0, 0],
9494
[0, 0, 0, 0, 0, 0],
9595
[3, 3, 0, 0, 0, 1],
96-
]
96+
],
97+
dtype="uint8",
98+
)
99+
input_data_int = dummy_lc_data["lc"].astype("uint8")
100+
101+
regrid_result = input_data_int.regrid.most_common(
102+
dummy_target_grid,
103+
values=EXP_LABELS,
97104
)
98105
xr.testing.assert_equal(
99-
dummy_lc_data["lc"].regrid.most_common(
100-
dummy_target_grid,
101-
values=EXP_LABELS,
102-
),
106+
regrid_result,
103107
make_expected_ds(expected_data)["lc"],
104108
)
105109

110+
assert regrid_result.dtype == input_data_int.dtype
111+
106112

107113
def test_least_common(dummy_lc_data, dummy_target_grid):
108114
# Currently just test if the method runs: code is 99% the same as most_common

0 commit comments

Comments
 (0)