Skip to content

Commit 1a441c4

Browse files
committed
fix(torchgeo): match TorchGeo merge semantics for pixel-accurate interop
Root cause: TorchGeo uses rasterio.merge.merge(bounds=..., res=...) to place pixels on the sampler's query grid. Rasteret was using rasterio.warp.reproject, which differs by up to 1 pixel at extent boundaries (confirmed on NAIP). Changes: - Add src/rasteret/core/rio_semantics.py: centralises all rasterio merge-vs-warp semantic decisions in one module. merge_semantic_resample_single_source() writes the reader's crop to a MemoryFile and delegates placement to rasterio.merge.merge, matching TorchGeo's _merge_or_stack() contract exactly. Handles south-up rasters (e.g. AEF, transform.e > 0) by flipping to a north-up equivalent before merge, consistent with TorchGeo's WarpedVRT normalisation in _load_warp_file(). - Adapter (torchgeo.py): replace _warp_to_grid (rasterio.warp.reproject) with merge_semantic_resample_single_source for same-CRS sampling. Use MergeGrid instead of compute_dst_grid for query grid definition. Fix time_series path to filter by sampler time slice before spatial query (was ignoring time). - COG reader (cog.py): fix UnboundLocalError in bounds= path by computing intersecting_tiles for direct-bounds reads. - Test oracle (test_dataset_pixel_comparison.py): use rasterio.merge.merge as ground truth (was rasterio.mask.mask), with WarpedVRT for south-up, and TorchGeo-consistent resampling (bilinear for float, nearest for int). Verified: 226 unit tests pass, 12/12 network pixel-comparison tests pass on rasterio 1.4.3.
1 parent 8b8bc7a commit 1a441c4

File tree

5 files changed

+434
-126
lines changed

5 files changed

+434
-126
lines changed

src/rasteret/core/rio_semantics.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright Terrafloww Labs, Inc.
3+
4+
"""Rasterio/GDAL semantic helpers (integration boundary).
5+
6+
Rasteret's core reader operates on a raster's native pixel grid. Some integrations
7+
(notably TorchGeo) define *query-grid* semantics: given bounds + resolution from a
8+
sampler, return pixels on that exact grid.
9+
10+
Rasterio exposes two different, widely-used semantics:
11+
- ``rasterio.merge.merge`` (gdal_merge-style window alignment)
12+
- ``rasterio.warp.reproject`` (warp semantics based on destination pixel centers)
13+
14+
These can differ sometimes in output. TorchGeo's
15+
``RasterDataset`` uses ``merge(bounds=..., res=...)`` in its read path, so Rasteret
16+
needs to match *merge semantics* for TorchGeo interop.
17+
"""
18+
19+
from __future__ import annotations
20+
21+
from dataclasses import dataclass
22+
from typing import Literal
23+
24+
import numpy as np
25+
from affine import Affine
26+
27+
28+
@dataclass(frozen=True)
29+
class MergeGrid:
30+
"""Query grid definition (bounds + resolution)."""
31+
32+
bounds: tuple[float, float, float, float] # (left, bottom, right, top)
33+
res: tuple[float, float] # (xres, yres) in CRS units, positive numbers
34+
35+
@property
36+
def transform(self) -> Affine:
37+
left, _bottom, _right, top = self.bounds
38+
xres, yres = self.res
39+
return Affine(xres, 0.0, float(left), 0.0, -yres, float(top))
40+
41+
@property
42+
def shape(self) -> tuple[int, int]:
43+
"""Match rasterio.merge.merge output shape computation (round)."""
44+
left, bottom, right, top = self.bounds
45+
xres, yres = self.res
46+
width = int(round((right - left) / xres))
47+
height = int(round((top - bottom) / yres))
48+
return height, width
49+
50+
51+
def merge_semantic_resample_single_source(
52+
src_crop: np.ndarray,
53+
*,
54+
src_crop_transform: Affine,
55+
src_full_transform: Affine,
56+
src_full_width: int,
57+
src_full_height: int,
58+
src_crs: int,
59+
grid: MergeGrid,
60+
resampling: Literal["nearest", "bilinear"],
61+
src_nodata: float | int | None,
62+
) -> np.ndarray:
63+
"""Resample *src_crop* onto *grid* using rasterio.merge.merge semantics.
64+
65+
This intentionally delegates the semantics to ``rasterio.merge.merge``.
66+
Rasterio's merge behavior is what TorchGeo uses, and it is known to differ
67+
from warp/reproject behavior by one pixel at extent boundaries.
68+
"""
69+
from rasterio.crs import CRS as RioCRS
70+
from rasterio.enums import Resampling
71+
from rasterio.io import MemoryFile
72+
from rasterio.merge import merge as rio_merge
73+
74+
if src_crop.ndim != 2:
75+
raise ValueError(f"Expected 2-D src_crop, got shape={src_crop.shape}")
76+
77+
# rasterio.merge.merge cannot merge upside-down rasters directly (south-up).
78+
# TorchGeo's merge-based semantics in practice behave like a north-up view.
79+
#
80+
# Normalize south-up sources into a north-up equivalent *without resampling*:
81+
# flip data vertically and adjust transforms to keep georeferencing correct.
82+
if float(src_crop_transform.e) > 0.0:
83+
src_crop = np.ascontiguousarray(src_crop[::-1, :])
84+
src_crop_transform = Affine(
85+
float(src_crop_transform.a),
86+
0.0,
87+
float(src_crop_transform.c),
88+
0.0,
89+
-float(src_crop_transform.e),
90+
float(src_crop_transform.f)
91+
+ float(src_crop.shape[0]) * float(src_crop_transform.e),
92+
)
93+
src_full_transform = Affine(
94+
float(src_full_transform.a),
95+
0.0,
96+
float(src_full_transform.c),
97+
0.0,
98+
-float(src_full_transform.e),
99+
float(src_full_transform.f)
100+
+ float(src_full_height) * float(src_full_transform.e),
101+
)
102+
103+
if (
104+
src_nodata is not None
105+
and isinstance(src_nodata, float)
106+
and np.isnan(src_nodata)
107+
):
108+
# A NaN nodata is only meaningful for floating rasters.
109+
if src_crop.dtype.kind != "f":
110+
src_nodata = None
111+
dst_h, dst_w = grid.shape
112+
if dst_h <= 0 or dst_w <= 0:
113+
return np.zeros((0, 0), dtype=src_crop.dtype)
114+
115+
# Clip the crop to the physical extent of the full raster.
116+
#
117+
# Rasteret's window-mode reader can return a *boundless* crop that extends
118+
# beyond the raster's physical extent (filled with 0/nodata). If we pass
119+
# this boundless crop directly to rasterio.merge.merge, merge will treat
120+
# those extended bounds as valid source bounds and can shift boundary
121+
# behavior (notably for sub-pixel-aligned query grids like NAIP).
122+
#
123+
# We therefore slice away any rows/cols that are fully outside the physical
124+
# extent so the in-memory dataset's bounds match the real raster's bounds.
125+
a = float(src_crop_transform.a)
126+
e = float(src_crop_transform.e)
127+
if a <= 0.0 or e >= 0.0:
128+
raise ValueError(
129+
"Expected a north-up transform for merge semantics " f"(a={a}, e={e})."
130+
)
131+
132+
crop_w = int(src_crop.shape[1])
133+
crop_h = int(src_crop.shape[0])
134+
crop_left = float(src_crop_transform.c)
135+
crop_top = float(src_crop_transform.f)
136+
crop_right = crop_left + crop_w * a
137+
crop_bottom = crop_top + crop_h * e
138+
139+
full_left = float(src_full_transform.c)
140+
full_top = float(src_full_transform.f)
141+
full_right = full_left + int(src_full_width) * float(src_full_transform.a)
142+
full_bottom = full_top + int(src_full_height) * float(src_full_transform.e)
143+
144+
inter_left = max(crop_left, full_left)
145+
inter_right = min(crop_right, full_right)
146+
inter_top = min(crop_top, full_top)
147+
inter_bottom = max(crop_bottom, full_bottom)
148+
149+
if inter_left < inter_right and inter_bottom < inter_top:
150+
import numpy as _np
151+
152+
col0 = int(_np.ceil((inter_left - crop_left) / a - 1e-9))
153+
col1 = int(_np.floor((inter_right - crop_left) / a + 1e-9))
154+
row0 = int(_np.ceil((crop_top - inter_top) / abs(e) - 1e-9))
155+
row1 = int(_np.floor((crop_top - inter_bottom) / abs(e) + 1e-9))
156+
157+
col0 = max(0, min(crop_w, col0))
158+
col1 = max(col0, min(crop_w, col1))
159+
row0 = max(0, min(crop_h, row0))
160+
row1 = max(row0, min(crop_h, row1))
161+
162+
if row0 != 0 or col0 != 0 or row1 != crop_h or col1 != crop_w:
163+
src_crop = src_crop[row0:row1, col0:col1]
164+
src_crop_transform = Affine(
165+
a,
166+
0.0,
167+
crop_left + col0 * a,
168+
0.0,
169+
e,
170+
crop_top + row0 * e,
171+
)
172+
173+
profile = {
174+
"driver": "GTiff",
175+
"height": int(src_crop.shape[0]),
176+
"width": int(src_crop.shape[1]),
177+
"count": 1,
178+
"dtype": src_crop.dtype,
179+
"crs": RioCRS.from_epsg(int(src_crs)),
180+
"transform": src_crop_transform,
181+
**({"nodata": src_nodata} if src_nodata is not None else {}),
182+
}
183+
184+
with MemoryFile() as mem:
185+
with mem.open(**profile) as ds:
186+
ds.write(src_crop, 1)
187+
188+
with mem.open() as ds:
189+
data, _out_transform = rio_merge(
190+
[ds],
191+
bounds=grid.bounds,
192+
res=grid.res,
193+
indexes=[1],
194+
resampling=getattr(Resampling, resampling),
195+
)
196+
return data.squeeze()

src/rasteret/fetch/cog.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,13 @@ async def read_cog(
13001300
[(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax), (xmin, ymin)]
13011301
],
13021302
}
1303+
intersecting_tiles = compute_tile_indices(
1304+
geometry_bbox=geom_bbox,
1305+
transform=metadata.transform,
1306+
tile_size=(metadata.tile_width, metadata.tile_height),
1307+
image_size=(metadata.width, metadata.height),
1308+
debug=debug,
1309+
)
13031310
elif geom_array is not None:
13041311
# Always compute the input bbox in the geometry CRS first. This lets us
13051312
# provide a clear error when the record CRS is missing and the geometry

0 commit comments

Comments
 (0)