Skip to content
61 changes: 61 additions & 0 deletions cf_xarray/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,64 @@ def cf_to_points(ds: xr.Dataset):
j += n

return xr.DataArray(geoms, dims=node_count.dims, coords=node_count.coords)


def grid_bounds_to_polygons(ds: xr.Dataset) -> xr.DataArray:
    """
    Converts a regular 2D lat/lon grid to a 2D array of shapely polygons.

    Modified from https://notebooksharing.space/view/c6c1f3a7d0c260724115eaa2bf78f3738b275f7f633c1558639e7bbd75b31456.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with "latitude" and "longitude" variables as well as their bounds variables.
        1D "latitude" and "longitude" variables are supported. This function will automatically
        broadcast them against each other.

    Returns
    -------
    DataArray
        DataArray with shapely polygon per grid cell. It has the dimensions and
        coordinates of the broadcast longitude bounds variable with the bounds
        dimension dropped.

    Raises
    ------
    ValueError
        If no bounds variable is registered for "latitude" or "longitude", or if
        the bounds dimension does not have exactly two entries.
    """
    from shapely import Polygon

    grid = ds.cf[["latitude", "longitude"]].load().reset_coords()
    bounds = ds.cf.bounds

    # Raise explicitly instead of using ``assert``: assertions are stripped
    # when Python runs with -O, which would let malformed input through.
    missing = [key for key in ("latitude", "longitude") if key not in bounds]
    if missing:
        raise ValueError(
            f"Dataset must have bounds variables for {missing!r}; "
            f"found bounds only for {sorted(bounds)!r}."
        )
    # The single-element unpacking requires exactly one bounds variable per axis.
    (lon_bounds,) = bounds["longitude"]
    (lat_bounds,) = bounds["latitude"]

    (points,) = xr.broadcast(grid)

    # Move the bounds dimension last so that ``[..., k]`` selects a bound.
    bounds_dim = grid.cf.get_bounds_dim_name("latitude")
    points = points.transpose(..., bounds_dim)
    if points.sizes[bounds_dim] != 2:
        raise ValueError(
            f"Bounds dimension {bounds_dim!r} must have size 2, "
            f"got {points.sizes[bounds_dim]}."
        )

    lonbnd = points[lon_bounds].data
    latbnd = points[lat_bounds].data

    # geopandas needs this
    # Fancy indexing returns a copy, so the in-place shift below does not
    # modify the data held by ``points``.
    expanded_lon = lonbnd[..., [0, 0, 1, 1]]
    # Shift cells whose first longitude bound is >= 180 down by 360.
    mask = expanded_lon[..., 0] >= 180
    expanded_lon[mask, :] = expanded_lon[mask, :] - 360

    # these magic numbers are OK :
    #   - 4 corners to a polygon, and
    #   - 2 from stacking lat, lon along the last axis
    # The corner order pairs lon [0, 0, 1, 1] with lat [0, 1, 1, 0] so the four
    # (lon, lat) vertices trace the cell outline in sequence.
    # flatten here to make iteration easier. It would be nice to avoid that.
    # potentially with just np.vectorize. The polygon creation is the real slow bit.
    # Shapely's MultiPolygon also iterates over a list in Python...
    blocked = np.stack([expanded_lon, latbnd[..., [0, 1, 1, 0]]], axis=-1).reshape(
        -1, 4, 2
    )
    polyarray = np.array([Polygon(ring) for ring in blocked], dtype="O")
    # Restore the grid shape (everything except the bounds dimension).
    newshape = latbnd.shape[:-1]
    polyarray = polyarray.reshape(newshape)
    # Reuse dims/coords of the longitude bounds, minus the bounds dimension.
    boxes = points[lon_bounds][..., 0].copy(data=polyarray)

    return boxes