Skip to content

Commit a88c521

Browse files
maawoomartinfleis
andauthored
BUG: Zonal stats with method='exactextract' not working on flat xarray.Dataarray (#132)
* fix #131 * fix typos * add new tests covering "flat" Dataset and Dataarray * lint --------- Co-authored-by: Martin Fleischmann <[email protected]>
1 parent 3ba5d0c commit a88c521

File tree

2 files changed

+68
-12
lines changed

2 files changed

+68
-12
lines changed

xvec/tests/test_zonal_stats.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,36 @@ def test_dataset(method):
147147
)
148148

149149

150+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
151+
def test_dataset_flat(method):
152+
ds = xr.tutorial.open_dataset("eraint_uvz").isel(month=0).isel(level=0)
153+
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
154+
result = ds.xvec.zonal_stats(world.geometry, "longitude", "latitude", method=method)
155+
156+
if method in ["exactextract", None]:
157+
xr.testing.assert_allclose(
158+
xr.Dataset(
159+
{
160+
"z": np.array(114857.63685302),
161+
"u": np.array(9.84182437),
162+
"v": np.array(-0.00330402),
163+
}
164+
),
165+
result.mean(),
166+
)
167+
else:
168+
xr.testing.assert_allclose(
169+
xr.Dataset(
170+
{
171+
"z": np.array(114302.08524294),
172+
"u": np.array(9.5196515),
173+
"v": np.array(0.29297792),
174+
}
175+
),
176+
result.drop_vars(["month", "level"]).mean(),
177+
)
178+
179+
150180
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
151181
def test_dataarray(method):
152182
ds = xr.tutorial.open_dataset("eraint_uvz")
@@ -163,6 +193,24 @@ def test_dataarray(method):
163193
assert result.mean() == pytest.approx(61367.76185577)
164194

165195

196+
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
197+
def test_dataarray_flat(method):
198+
ds = xr.tutorial.open_dataset("eraint_uvz")
199+
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
200+
result = (
201+
ds.z.isel(month=0)
202+
.isel(level=0)
203+
.xvec.zonal_stats(world.geometry, "longitude", "latitude", method=method)
204+
)
205+
206+
assert result.shape == (127,)
207+
assert result.dims == ("geometry",)
208+
if method in ["exactextract", None]:
209+
assert result.mean() == pytest.approx(114857.63685302)
210+
else:
211+
assert result.mean() == pytest.approx(114302.08524294)
212+
213+
166214
@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
167215
def test_stat(method):
168216
ds = xr.tutorial.open_dataset("eraint_uvz")

xvec/zonal.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -698,8 +698,14 @@ def _agg_exactextract(
698698

699699
# Stack the other dimensions into one dimension called "location"
700700
arr_dims = tuple(dim for dim in acc._obj.dims if dim not in [x_coords, y_coords])
701-
data = acc._obj.stack(location=arr_dims)
702-
locs = data.location.size
701+
if arr_dims:
702+
# Stack non-spatial dimensions if they exist
703+
data = acc._obj.stack(location=arr_dims)
704+
locs = data.location.size
705+
else:
706+
# No additional dimensions to stack, create a dummy "location" dimension
707+
data = acc._obj.expand_dims("location")
708+
locs = 1
703709

704710
# Check the order of dimensions
705711
data = data.transpose("location", y_coords, x_coords)
@@ -713,23 +719,25 @@ def _agg_exactextract(
713719
results = exactextract.exact_extract(
714720
rast=data, vec=gdf, ops=stats, output="pandas", strategy=strategy
715721
)
716-
# Get all the dimensions execpt x_coords, y_coords, they will be used to stack the
722+
# Get all the dimensions except x_coords, y_coords, they will be used to stack the
717723
# dataarray later
718724
if original_is_ds is True:
719-
# Get the original dataset information to use for unstacking the resulte later
725+
# Get the original dataset information to use for unstacking the result later
720726
coords_info = {name: geometry}
721727
original_shape = [len(geometry)]
722-
for dim in arr_dims:
723-
original_shape.append(acc._obj[dim].size)
724-
if dim != "variable":
725-
coords_info[dim] = acc._obj[dim].values
728+
if arr_dims:
729+
for dim in arr_dims:
730+
original_shape.append(acc._obj[dim].size)
731+
if dim != "variable":
732+
coords_info[dim] = acc._obj[dim].values
726733
else:
727-
# Get the original dataarray information to use for unstacking the resulte later
734+
# Get the original dataarray information to use for unstacking the result later
728735
coords_info = {name: geometry}
729736
original_shape = [len(geometry)]
730-
for dim in arr_dims:
731-
original_shape.append(acc._obj[dim].size)
732-
coords_info[dim] = acc._obj[dim].values
737+
if arr_dims:
738+
for dim in arr_dims:
739+
original_shape.append(acc._obj[dim].size)
740+
coords_info[dim] = acc._obj[dim].values
733741
return results, original_shape, coords_info, locs
734742

735743

0 commit comments

Comments
 (0)