Skip to content

Commit f8f23fe

Browse files
committed
Fix _get_ordered_vertices() to preserve leading dims
- Also handle strictly monotonic bounds
1 parent 08f4c09 commit f8f23fe

File tree

2 files changed

+154
-32
lines changed

2 files changed

+154
-32
lines changed

cf_xarray/helpers.py

Lines changed: 108 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -216,49 +216,126 @@ def _bounds_helper(values, n_core_dims, nbounds, order):
216216
return vertex_vals
217217

218218

219-
def _get_ordered_vertices(bounds: xr.DataArray) -> np.ndarray:
220-
"""Extracts a sorted 1D array of unique vertex values from a bounds DataArray.
221-
222-
This function takes a DataArray (or array-like) containing bounds information,
223-
typically as pairs of values along the last dimension. It flattens the
224-
bounds into pairs, extracts all unique vertex values, and returns them in
225-
sorted order. The sorting order (ascending or descending) is determined by
226-
inspecting the direction of the first non-equal bounds pair.
219+
def _get_ordered_vertices(bounds: np.ndarray) -> np.ndarray:
220+
"""
221+
Convert a bounds array of shape (..., N, 2) or (N, 2) into a 1D array of
222+
vertices.
223+
224+
This function reconstructs the vertices from a bounds array, handling both
225+
strictly monotonic and non-strictly monotonic bounds. For strictly monotonic
226+
bounds, it concatenates the left endpoints and the last right endpoint. For
227+
non-strictly monotonic bounds (i.e., bounds that are consistently ascending or
228+
descending within their intervals, but not strictly so), it uses the minimum
229+
of each interval as the lower endpoint and the maximum of the last interval as
230+
the final vertex, then sorts the vertices in ascending or descending order to
231+
match the direction of the bounds.
232+
233+
- Handles both ascending and descending bounds.
234+
- Does not require bounds to be strictly monotonic.
235+
- Preserves repeated coordinates if present.
236+
- Output shape is (..., N+1) or (N+1,).
227237
228238
Parameters
229239
----------
230-
bounds : xr.DataArray
231-
A DataArray containing bounds information, typically with shape (..., 2).
240+
bounds : np.ndarray
241+
An array containing bounds information, typically with shape (N, 2)
242+
or (..., N, 2).
232243
233244
Returns
234245
-------
235246
np.ndarray
236-
A 1D NumPy array of sorted unique vertex values extracted from the
237-
bounds.
247+
An array of vertices with shape (..., N+1) or (N+1,).
238248
"""
239-
# Convert to array if needed
240-
arr = bounds.values if isinstance(bounds, xr.DataArray) else bounds
241-
arr = np.asarray(arr)
242-
243-
# Flatten to (N, 2) pairs and get all unique values.
244-
pairs = arr.reshape(-1, 2)
245-
vertices = np.unique(pairs)
246-
247-
# Determine order: find the first pair with different values
248-
ascending = True
249-
for left, right in pairs:
250-
if left != right:
251-
ascending = right > left
252-
break
253-
254-
# Sort vertices in ascending or descending order as needed.
255-
vertices = np.sort(vertices)
256-
if not ascending:
257-
vertices = vertices[::-1]
249+
if _is_bounds_strictly_monotonic(bounds):
250+
# Example: [[51.0, 50.5], [50.5, 50.0]]
251+
# Example Result: [51.0, 50.5, 50.0]
252+
vertices = np.concatenate((bounds[..., :, 0], bounds[..., -1:, 1]), axis=-1)
253+
else:
254+
# Example with bounds (descending) [[50.5, 50.0], [51.0, 50.5]]
255+
# Get the lower endpoints of each bounds interval
256+
# Example Result: [50, 50.5]
257+
lower_endpoints = np.minimum(bounds[..., :, 0], bounds[..., :, 1])
258+
259+
# Get the upper endpoint of the last interval.
260+
# Example Result: 51.0
261+
last_upper_endpoint = np.maximum(bounds[..., -1, 0], bounds[..., -1, 1])
262+
263+
# Concatenate lower endpoints and the last upper endpoint.
264+
# Example Result: [50.0, 50.5, 51.0]
265+
vertices = np.concatenate(
266+
[lower_endpoints, np.expand_dims(last_upper_endpoint, axis=-1)], axis=-1
267+
)
268+
269+
# Sort vertices based on the direction of the bounds
270+
# Example Result: [51.0, 50.5, 50.0]
271+
ascending = is_bounds_ascending(bounds)
272+
if ascending:
273+
vertices = np.sort(vertices, axis=-1)
274+
else:
275+
vertices = np.sort(vertices, axis=-1)[..., ::-1]
258276

259277
return vertices
260278

261279

280+
def _is_bounds_strictly_monotonic(arr: np.ndarray) -> bool:
281+
"""
282+
Check if the second-to-last axis of a numpy array is strictly monotonic.
283+
284+
This function checks if the second-to-last axis of the input array is
285+
strictly monotonic (either strictly increasing or strictly decreasing)
286+
for arrays of shape (..., N, 2), preserving leading dimensions.
287+
288+
Parameters
289+
----------
290+
arr : np.ndarray
291+
Numpy array to check, typically with shape (..., N, 2).
292+
293+
Returns
294+
-------
295+
bool
296+
True if the array is strictly monotonic along the second-to-last axis
297+
for all leading dimensions, False otherwise.
298+
299+
Examples
300+
--------
301+
>>> bounds = np.array([
302+
... [76.25, 73.75],
303+
... [73.75, 71.25],
304+
... [71.25, 68.75],
305+
... [68.75, 66.25],
306+
... [66.25, 63.75]
307+
... ], dtype=np.float32)
308+
>>> _is_bounds_strictly_monotonic(bounds)
309+
True
310+
"""
311+
diffs = np.diff(arr, axis=-2)
312+
strictly_increasing = np.all(diffs > 0, axis=-2)
313+
strictly_decreasing = np.all(diffs < 0, axis=-2)
314+
315+
return np.all(strictly_increasing | strictly_decreasing)
316+
317+
318+
def is_bounds_ascending(bounds: np.ndarray) -> bool:
319+
"""Check if bounds are in ascending order.
320+
321+
Parameters
322+
----------
323+
bounds : np.ndarray
324+
An array containing bounds information, typically with shape (N, 2)
325+
or (..., N, 2).
326+
327+
Returns
328+
-------
329+
bool
330+
True if bounds are in ascending order, False if they are in descending
331+
order.
332+
"""
333+
lower = bounds[..., :, 0]
334+
upper = bounds[..., :, 1]
335+
336+
return np.all(lower < upper)
337+
338+
262339
def vertices_to_bounds(
263340
vertices: DataArray, out_dims: Sequence[str] = ("bounds", "x", "y")
264341
) -> DataArray:

cf_xarray/tests/test_helpers.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import xarray as xr
12
from numpy.testing import assert_array_equal
23
from xarray.testing import assert_equal
34

@@ -12,7 +13,7 @@
1213

1314

1415
def test_bounds_to_vertices() -> None:
15-
# 1D case
16+
# 1D case (stricly monotonic, descending bounds)
1617
ds = airds.cf.add_bounds(["lon", "lat", "time"])
1718
lat_c = cfxr.bounds_to_vertices(ds.lat_bounds, bounds_dim="bounds")
1819
assert_array_equal(ds.lat.values + 1.25, lat_c.values[:-1])
@@ -34,6 +35,50 @@ def test_bounds_to_vertices() -> None:
3435
lon_no = cfxr.bounds_to_vertices(rotds.lon_bounds, bounds_dim="bounds", order=None)
3536
assert_equal(lon_no, lon_ccw)
3637

38+
# 2D case (descending)
39+
bounds_2d_desc = xr.DataArray(
40+
[[50.5, 50.0], [51.0, 50.5], [51.0, 50.5], [52.0, 51.5], [52.5, 52.0]],
41+
dims=("lat", "bounds"),
42+
)
43+
expected_vertices_2d_desc = xr.DataArray(
44+
[52.5, 52.0, 51.5, 50.5, 50.5, 50.0],
45+
dims=["lat_vertices"],
46+
)
47+
vertices_2d_desc = cfxr.bounds_to_vertices(bounds_2d_desc, bounds_dim="bounds")
48+
assert_equal(expected_vertices_2d_desc, vertices_2d_desc)
49+
50+
# 3D case (ascending, "extra" non-core dim should be preserved)
51+
bounds_3d = xr.DataArray(
52+
[
53+
[
54+
[50.0, 50.5],
55+
[50.5, 51.0],
56+
[51.0, 51.5],
57+
[51.5, 52.0],
58+
[52.0, 52.5],
59+
],
60+
[
61+
[60.0, 60.5],
62+
[60.5, 61.0],
63+
[61.0, 61.5],
64+
[61.5, 62.0],
65+
[62.0, 62.5],
66+
],
67+
],
68+
dims=("extra", "lat", "bounds"),
69+
)
70+
expected_vertices_3d = xr.DataArray(
71+
[
72+
[50.0, 50.5, 51.0, 51.5, 52.0, 52.5],
73+
[60.0, 60.5, 61.0, 61.5, 62.0, 62.5],
74+
],
75+
dims=("extra", "lat_vertices"),
76+
)
77+
vertices_3d = cfxr.bounds_to_vertices(
78+
bounds_3d, bounds_dim="bounds", core_dims=["lat"]
79+
)
80+
assert_equal(vertices_3d, expected_vertices_3d)
81+
3782
# Transposing the array changes the bounds direction
3883
ds = mollwds.transpose("x", "y", "x_vertices", "y_vertices", "bounds")
3984
lon_cw = cfxr.bounds_to_vertices(

0 commit comments

Comments
 (0)