Skip to content

Commit 0b84419

Browse files
dcherian and claude authored
proptest: Indexing rules match PandasIndex (#57)
Co-authored-by: Claude <[email protected]>
1 parent e144693 commit 0b84419

File tree

4 files changed

+367
-5
lines changed

4 files changed

+367
-5
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ dependencies = [
3030
"pandas>=2",
3131
"numpy>=2",
3232
"xproj>=0.2.0",
33-
"xarray>=2025",
33+
"xarray @ git+https://github.com/dcherian/xarray.git@fix-coord-transform-indexing",
3434
]
3535
dynamic=["version"]
3636

@@ -75,6 +75,9 @@ docs = [
7575
[tool.hatch]
7676
version.source = "vcs"
7777

78+
[tool.hatch.metadata]
79+
allow-direct-references = true
80+
7881
[tool.hatch.build]
7982
hooks.vcs.version-file = "src/rasterix/_version.py"
8083

src/rasterix/raster_index.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,17 +310,20 @@ def isel( # type: ignore[override]
310310
def sel(self, labels, method=None, tolerance=None):
311311
coord_name = self.axis_transform.coord_name
312312
label = labels[coord_name]
313-
313+
transform = self.axis_transform
314314
if isinstance(label, slice):
315-
if label.start is None:
316-
label = slice(0, label.stop, label.step)
315+
label = slice(
316+
label.start or transform.forward({coord_name: 0})[coord_name],
317+
label.stop or transform.forward({coord_name: transform.size})[coord_name],
318+
label.step,
319+
)
317320
if label.step is None:
318321
# continuous interval slice indexing (preserves the index)
319322
pos = self.transform.reverse({coord_name: np.array([label.start, label.stop])})
320323
# np.round rounds to even, this way we round upwards
321324
pos = np.floor(pos[self.dim] + 0.5).astype("int")
322325
new_start = max(pos[0], 0)
323-
new_stop = min(pos[1], self.axis_transform.size)
326+
new_stop = min(pos[1] + 1, self.axis_transform.size)
324327
return IndexSelResult({self.dim: slice(new_start, new_stop)})
325328
else:
326329
# otherwise convert to basic (array) indexing

src/rasterix/strategies.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""Hypothesis strategies for generating label-based indexers."""
2+
3+
from collections.abc import Hashable
4+
from typing import Any
5+
6+
import numpy as np
7+
import pandas as pd
8+
import xarray as xr
9+
from hypothesis import note
10+
from hypothesis import strategies as st
11+
from xarray.core.indexes import Indexes
12+
from xarray.testing.strategies import (
13+
basic_indexers,
14+
outer_array_indexers,
15+
vectorized_indexers,
16+
)
17+
18+
19+
def pos_to_label_indexer(idx: pd.Index, idxr: int | slice | np.ndarray, *, use_scalar: bool = True) -> Any:
20+
"""Convert a positional indexer to a label-based indexer.
21+
22+
Parameters
23+
----------
24+
idx : pd.Index
25+
The pandas Index to use for label lookup.
26+
idxr : int | slice | np.ndarray
27+
The positional indexer (integer, slice, or array of integers).
28+
use_scalar : bool, optional
29+
If True, attempt to convert scalar values to Python scalars. Default is True.
30+
31+
Returns
32+
-------
33+
Any
34+
The label-based indexer (scalar, slice, or array of labels).
35+
"""
36+
if isinstance(idxr, slice):
37+
return slice(
38+
None if idxr.start is None else idx[idxr.start],
39+
# FIXME: This will never go past the label range
40+
None if idxr.stop is None else idx[min(idxr.stop, idx.size - 1)],
41+
)
42+
elif isinstance(idxr, np.ndarray):
43+
# Convert array of position indices to array of label values
44+
return idx[idxr].values
45+
else:
46+
val = idx[idxr]
47+
if use_scalar:
48+
try:
49+
# pass python scalars occasionally
50+
val = val.item()
51+
except Exception:
52+
note(f"casting {val!r} to item() failed")
53+
pass
54+
return val
55+
56+
57+
@st.composite
def basic_label_indexers(draw, /, *, indexes: Indexes) -> dict[Hashable, float | slice]:
    """Draw basic (scalar / slice) indexers expressed in label space.

    A positional indexer is drawn from ``basic_indexers`` and every entry is
    then mapped to labels through the matching pandas index.

    Parameters
    ----------
    draw : callable
        The Hypothesis draw function (supplied by ``@st.composite``).
    indexes : Indexes
        Indexes of the object being indexed. Every unique index is asserted
        to be a ``PandasIndex``.

    Returns
    -------
    dict[Hashable, float | slice]
        One entry per key of the drawn positional indexer; each value is a
        single label or a slice of labels.
    """
    for unique_index in indexes.get_unique():
        assert isinstance(unique_index, xr.indexes.PandasIndex)

    # FIXME: this should be indexes.sizes!
    positional = draw(basic_indexers(sizes=indexes.dims))
    pandas_indexes = indexes.to_pandas_indexes()

    labelled = {}
    for dim, positions in positional.items():
        pandas_index = pandas_indexes[dim]
        # Randomly decide per dimension whether to try converting a scalar label via .item().
        as_python_scalar = draw(st.booleans())
        labelled[dim] = pos_to_label_indexer(pandas_index, positions, use_scalar=as_python_scalar)
    return labelled
90+
91+
92+
@st.composite
def outer_array_label_indexers(draw, /, *, indexes: Indexes) -> dict[Hashable, np.ndarray]:
    """Draw outer (per-dimension array) indexers expressed in label space.

    A positional indexer is drawn from ``outer_array_indexers`` and every
    entry is then mapped to labels through the matching pandas index.

    Parameters
    ----------
    draw : callable
        The Hypothesis draw function (supplied by ``@st.composite``).
    indexes : Indexes
        Indexes of the object being indexed. Every unique index is asserted
        to be a ``PandasIndex``.

    Returns
    -------
    dict[Hashable, np.ndarray]
        One entry per key of the drawn positional indexer; each value holds
        the labels found at the drawn positions.
    """
    for unique_index in indexes.get_unique():
        assert isinstance(unique_index, xr.indexes.PandasIndex)

    # FIXME: this should be indexes.sizes!
    positional = draw(outer_array_indexers(sizes=indexes.dims))
    pandas_indexes = indexes.to_pandas_indexes()

    labelled = {}
    for dim, positions in positional.items():
        # Labels are never converted with .item() here (use_scalar=False).
        labelled[dim] = pos_to_label_indexer(pandas_indexes[dim], positions, use_scalar=False)
    return labelled
124+
125+
126+
@st.composite
def vectorized_label_indexers(draw, /, *, indexes: Indexes, **kwargs) -> dict[Hashable, xr.DataArray]:
    """Draw vectorized (DataArray) indexers expressed in label space.

    A positional indexer is drawn from ``vectorized_indexers`` and every
    position it contains is replaced by the label at that position.

    Parameters
    ----------
    draw : callable
        The Hypothesis draw function (supplied by ``@st.composite``).
    indexes : Indexes
        Indexes of the object being indexed. Every unique index is asserted
        to be a ``PandasIndex``.
    **kwargs : dict
        Forwarded unchanged to ``vectorized_indexers``.

    Returns
    -------
    dict[Hashable, xr.DataArray]
        One entry per key of the drawn positional indexer; each value is a
        DataArray of labels with the same shape and dims as the drawn
        positional array.
    """
    for unique_index in indexes.get_unique():
        assert isinstance(unique_index, xr.indexes.PandasIndex)

    # FIXME: this should be indexes.sizes!
    positional = draw(vectorized_indexers(sizes=indexes.dims, **kwargs))
    pandas_indexes = indexes.to_pandas_indexes()

    # The pandas index is indexed with a 1-D array, so flatten the positions,
    # look up the labels, and restore the original shape afterwards.
    return {
        dim: xr.DataArray(
            pandas_indexes[dim][positions.values.ravel()].values.reshape(positions.shape),
            dims=positions.dims,
        )
        for dim, positions in positional.items()
    }

0 commit comments

Comments
 (0)