Skip to content

Commit c204818

Browse files
authored
ENH: allowing masking nodata in zonal_stats (#123)
1 parent 801eb7f commit c204818

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

xvec/accessor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,7 @@ def zonal_stats(
985985
method: str = "rasterize",
986986
all_touched: bool = False,
987987
n_jobs: int = -1,
988+
nodata: Any = None,
988989
**kwargs: dict[str, Any],
989990
) -> xr.DataArray | xr.Dataset:
990991
"""Extract the values from a dataset indexed by a set of geometries
@@ -1062,6 +1063,9 @@ def zonal_stats(
10621063
Number of parallel threads to use. It is recommended to set this to the
10631064
number of physical cores of the CPU. ``-1`` uses all available cores.
10641065
Applies only if ``method="iterate"``.
1066+
nodata : Any
1067+
Value representing missing data. If not specified, the value is included in
1068+
the aggregation.
10651069
**kwargs : optional
10661070
Keyword arguments to be passed to the aggregation function
10671071
(e.g., ``Dataset.quantile(**kwargs)``).
@@ -1152,6 +1156,7 @@ def zonal_stats(
11521156
y_coords=y_coords,
11531157
stats=stats,
11541158
all_touched=all_touched,
1159+
nodata=nodata,
11551160
)
11561161

11571162
if method == "rasterize":
@@ -1163,6 +1168,7 @@ def zonal_stats(
11631168
stats=stats,
11641169
name=name,
11651170
all_touched=all_touched,
1171+
nodata=nodata,
11661172
**kwargs,
11671173
)
11681174
elif method == "iterate":
@@ -1175,6 +1181,7 @@ def zonal_stats(
11751181
name=name,
11761182
all_touched=all_touched,
11771183
n_jobs=n_jobs,
1184+
nodata=nodata,
11781185
**kwargs,
11791186
)
11801187
elif method == "exactextract":
@@ -1185,6 +1192,7 @@ def zonal_stats(
11851192
y_coords=y_coords,
11861193
stats=stats,
11871194
name=name,
1195+
nodata=nodata,
11881196
**kwargs,
11891197
)
11901198
else:

xvec/tests/test_zonal_stats.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,19 @@ def test_exactextract_strategy():
456456
method="exactextract",
457457
strategy="invalid_strategy",
458458
)
459+
460+
461+
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
462+
def test_nodata(method):
463+
ds = xr.tutorial.open_dataset("eraint_uvz")
464+
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
465+
466+
arr = ds.z.where(ds.z > ds.z.mean(), -9999)
467+
unmasked = arr.xvec.zonal_stats(
468+
world.geometry, "longitude", "latitude", method=method
469+
)
470+
masked = arr.xvec.zonal_stats(
471+
world.geometry, "longitude", "latitude", method=method, nodata=-9999
472+
)
473+
474+
assert unmasked.mean() < masked.mean()

xvec/zonal.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _zonal_stats_rasterize(
3333
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
3434
name: str = "geometry",
3535
all_touched: bool = False,
36+
nodata: Any = None,
3637
**kwargs,
3738
) -> xr.DataArray | xr.Dataset:
3839
try:
@@ -70,6 +71,11 @@ def _zonal_stats_rasterize(
7071
unique.remove(length)
7172

7273
obj = acc._obj.copy()
74+
75+
# mask out nodata - note that this casts whole array to float
76+
if nodata is not None:
77+
obj = obj.where(obj != nodata)
78+
7379
if isinstance(obj, xr.Dataset):
7480
obj = obj.assign_coords(
7581
__labels__=xr.DataArray(labels, dims=(y_coords, x_coords))
@@ -124,6 +130,7 @@ def _zonal_stats_iterative(
124130
name: str = "geometry",
125131
all_touched: bool = False,
126132
n_jobs: int = -1,
133+
nodata: Any = None,
127134
**kwargs: dict[str, Any],
128135
) -> xr.DataArray | xr.Dataset:
129136
"""Extract the values from a dataset indexed by a set of geometries
@@ -158,6 +165,9 @@ def _zonal_stats_iterative(
158165
n_jobs : int, optional
159166
Number of parallel threads to use.
160167
It is recommended to set this to the number of physical cores in the CPU.
168+
nodata : Any
169+
Value representing missing data. If not specified, the value is included in
170+
the aggregation.
161171
**kwargs : optional
162172
Keyword arguments to be passed to the aggregation function
163173
(as ``Dataset.mean(**kwargs)``).
@@ -198,6 +208,7 @@ def _zonal_stats_iterative(
198208
y_coords,
199209
stats=stats,
200210
all_touched=all_touched,
211+
nodata=nodata,
201212
**kwargs,
202213
)
203214
for geom in geometry
@@ -224,6 +235,7 @@ def _agg_geom(
224235
y_coords: str | None = None,
225236
stats: str | Callable | Iterable[str | Callable | tuple] = "mean",
226237
all_touched: bool = False,
238+
nodata: Any = None,
227239
**kwargs,
228240
):
229241
"""Aggregate the values from a dataset over a polygon geometry.
@@ -250,6 +262,9 @@ def _agg_geom(
250262
If True, all pixels touched by geometries will be considered. If False, only
251263
pixels whose center is within the polygon or that are selected by
252264
Bresenham’s line algorithm will be considered.
265+
nodata : Any
266+
Value representing missing data. If not specified, the value is included in
267+
the aggregation.
253268
254269
Returns
255270
-------
@@ -270,6 +285,8 @@ def _agg_geom(
270285
all_touched=all_touched,
271286
)
272287
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
288+
if nodata is not None:
289+
masked = masked.where(masked != nodata)
273290
if pd.api.types.is_list_like(stats):
274291
agg = {}
275292
for stat in stats: # type: ignore
@@ -309,6 +326,7 @@ def _zonal_stats_exactextract(
309326
y_coords: Hashable,
310327
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
311328
name: str = "geometry",
329+
nodata: Any = None,
312330
**kwargs,
313331
) -> xr.DataArray | xr.Dataset:
314332
"""Extract the values from a dataset indexed by a set of geometries
@@ -334,6 +352,9 @@ def _zonal_stats_exactextract(
334352
``"quantile(q=0.20)"``)
335353
name : str, optional
336354
Name of the dimension that will hold the ``geometry``, by default "geometry"
355+
nodata : Any
356+
Value representing missing data. If not specified, the value is included in
357+
the aggregation.
337358
338359
Returns
339360
-------
@@ -372,6 +393,7 @@ def _zonal_stats_exactextract(
372393
stats,
373394
name,
374395
original_is_ds,
396+
nodata=nodata,
375397
**kwargs,
376398
)
377399
i = 0
@@ -410,6 +432,7 @@ def _zonal_stats_exactextract(
410432
stats,
411433
name,
412434
original_is_ds,
435+
nodata=nodata,
413436
**kwargs,
414437
)
415438
# Unstack the result
@@ -447,6 +470,7 @@ def _agg_exactextract(
447470
name: str = "geometry",
448471
original_is_ds: bool = False,
449472
strategy: str = "feature-sequential",
473+
nodata: Any = None,
450474
):
451475
"""Extract the values from a dataset indexed by a set of geometries
452476
@@ -476,6 +500,9 @@ def _agg_exactextract(
476500
If True, all pixels touched by geometries will be considered. If False, only
477501
pixels whose center is within the polygon or that are selected by
478502
Bresenham’s line algorithm will be considered.
503+
nodata : Any
504+
Value representing missing data. If not specified, the value is included in
505+
the aggregation.
479506
strategy : str, optional
480507
The strategy to use for the extraction, by default "feature-sequential"
481508
Use either "feature-sequential" and "raster-sequential".
@@ -511,6 +538,10 @@ def _agg_exactextract(
511538
# Check the order of dimensions
512539
data = data.transpose("location", y_coords, x_coords)
513540

541+
# mask nodata
542+
if nodata is not None:
543+
data = data.where(data != nodata)
544+
514545
# Aggregation result
515546
gdf = gpd.GeoDataFrame(geometry=geometry, crs=crs)
516547
results = exactextract.exact_extract(
@@ -537,7 +568,16 @@ def _agg_exactextract(
537568

538569

539570
def _get_mean(
540-
geom_arr, obj, x_coords, y_coords, transform, all_touched, stats, dims, **kwargs
571+
geom_arr,
572+
obj,
573+
x_coords,
574+
y_coords,
575+
transform,
576+
all_touched,
577+
stats,
578+
dims,
579+
nodata,
580+
**kwargs,
541581
):
542582
from rasterio import features
543583

@@ -552,6 +592,10 @@ def _get_mean(
552592
all_touched=all_touched,
553593
)
554594
masked = obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
595+
596+
if nodata is not None:
597+
masked = masked.where(masked != nodata)
598+
555599
if pd.api.types.is_list_like(stats):
556600
agg = {}
557601
for stat in stats: # type: ignore
@@ -589,6 +633,7 @@ def _variable_zonal(
589633
y_coords: Hashable,
590634
stats="mean",
591635
all_touched: bool = False,
636+
nodata: Any = None,
592637
):
593638
transform = acc._obj.rio.transform()
594639
dims = variable_geometry.dims
@@ -597,7 +642,7 @@ def _variable_zonal(
597642

598643
for x in stacked:
599644
m = _get_mean(
600-
x, acc._obj, x_coords, y_coords, transform, all_touched, stats, dims
645+
x, acc._obj, x_coords, y_coords, transform, all_touched, stats, dims, nodata
601646
)
602647
m.name = "statistics"
603648
r.append(m)

0 commit comments

Comments
 (0)