|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright Terrafloww Labs, Inc. |
| 3 | + |
| 4 | +"""Rasterio/GDAL semantic helpers (integration boundary). |
| 5 | +
|
| 6 | +Rasteret's core reader operates on a raster's native pixel grid. Some integrations |
| 7 | +(notably TorchGeo) define *query-grid* semantics: given bounds + resolution from a |
| 8 | +sampler, return pixels on that exact grid. |
| 9 | +
|
| 10 | +Rasterio exposes two different, widely-used semantics: |
| 11 | + - ``rasterio.merge.merge`` (gdal_merge-style window alignment) |
| 12 | + - ``rasterio.warp.reproject`` (warp semantics based on destination pixel centers) |
| 13 | +
|
| 14 | +These can differ sometimes in output. TorchGeo's |
| 15 | +``RasterDataset`` uses ``merge(bounds=..., res=...)`` in its read path, so Rasteret |
| 16 | +needs to match *merge semantics* for TorchGeo interop. |
| 17 | +""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +from dataclasses import dataclass |
| 22 | +from typing import Literal |
| 23 | + |
| 24 | +import numpy as np |
| 25 | +from affine import Affine |
| 26 | + |
| 27 | + |
| 28 | +@dataclass(frozen=True) |
| 29 | +class MergeGrid: |
| 30 | + """Query grid definition (bounds + resolution).""" |
| 31 | + |
| 32 | + bounds: tuple[float, float, float, float] # (left, bottom, right, top) |
| 33 | + res: tuple[float, float] # (xres, yres) in CRS units, positive numbers |
| 34 | + |
| 35 | + @property |
| 36 | + def transform(self) -> Affine: |
| 37 | + left, _bottom, _right, top = self.bounds |
| 38 | + xres, yres = self.res |
| 39 | + return Affine(xres, 0.0, float(left), 0.0, -yres, float(top)) |
| 40 | + |
| 41 | + @property |
| 42 | + def shape(self) -> tuple[int, int]: |
| 43 | + """Match rasterio.merge.merge output shape computation (round).""" |
| 44 | + left, bottom, right, top = self.bounds |
| 45 | + xres, yres = self.res |
| 46 | + width = int(round((right - left) / xres)) |
| 47 | + height = int(round((top - bottom) / yres)) |
| 48 | + return height, width |
| 49 | + |
| 50 | + |
| 51 | +def merge_semantic_resample_single_source( |
| 52 | + src_crop: np.ndarray, |
| 53 | + *, |
| 54 | + src_crop_transform: Affine, |
| 55 | + src_full_transform: Affine, |
| 56 | + src_full_width: int, |
| 57 | + src_full_height: int, |
| 58 | + src_crs: int, |
| 59 | + grid: MergeGrid, |
| 60 | + resampling: Literal["nearest", "bilinear"], |
| 61 | + src_nodata: float | int | None, |
| 62 | +) -> np.ndarray: |
| 63 | + """Resample *src_crop* onto *grid* using rasterio.merge.merge semantics. |
| 64 | +
|
| 65 | + This intentionally delegates the semantics to ``rasterio.merge.merge``. |
| 66 | + Rasterio's merge behavior is what TorchGeo uses, and it is known to differ |
| 67 | + from warp/reproject behavior by one pixel at extent boundaries. |
| 68 | + """ |
| 69 | + from rasterio.crs import CRS as RioCRS |
| 70 | + from rasterio.enums import Resampling |
| 71 | + from rasterio.io import MemoryFile |
| 72 | + from rasterio.merge import merge as rio_merge |
| 73 | + |
| 74 | + if src_crop.ndim != 2: |
| 75 | + raise ValueError(f"Expected 2-D src_crop, got shape={src_crop.shape}") |
| 76 | + |
| 77 | + # rasterio.merge.merge cannot merge upside-down rasters directly (south-up). |
| 78 | + # TorchGeo's merge-based semantics in practice behave like a north-up view. |
| 79 | + # |
| 80 | + # Normalize south-up sources into a north-up equivalent *without resampling*: |
| 81 | + # flip data vertically and adjust transforms to keep georeferencing correct. |
| 82 | + if float(src_crop_transform.e) > 0.0: |
| 83 | + src_crop = np.ascontiguousarray(src_crop[::-1, :]) |
| 84 | + src_crop_transform = Affine( |
| 85 | + float(src_crop_transform.a), |
| 86 | + 0.0, |
| 87 | + float(src_crop_transform.c), |
| 88 | + 0.0, |
| 89 | + -float(src_crop_transform.e), |
| 90 | + float(src_crop_transform.f) |
| 91 | + + float(src_crop.shape[0]) * float(src_crop_transform.e), |
| 92 | + ) |
| 93 | + src_full_transform = Affine( |
| 94 | + float(src_full_transform.a), |
| 95 | + 0.0, |
| 96 | + float(src_full_transform.c), |
| 97 | + 0.0, |
| 98 | + -float(src_full_transform.e), |
| 99 | + float(src_full_transform.f) |
| 100 | + + float(src_full_height) * float(src_full_transform.e), |
| 101 | + ) |
| 102 | + |
| 103 | + if ( |
| 104 | + src_nodata is not None |
| 105 | + and isinstance(src_nodata, float) |
| 106 | + and np.isnan(src_nodata) |
| 107 | + ): |
| 108 | + # A NaN nodata is only meaningful for floating rasters. |
| 109 | + if src_crop.dtype.kind != "f": |
| 110 | + src_nodata = None |
| 111 | + dst_h, dst_w = grid.shape |
| 112 | + if dst_h <= 0 or dst_w <= 0: |
| 113 | + return np.zeros((0, 0), dtype=src_crop.dtype) |
| 114 | + |
| 115 | + # Clip the crop to the physical extent of the full raster. |
| 116 | + # |
| 117 | + # Rasteret's window-mode reader can return a *boundless* crop that extends |
| 118 | + # beyond the raster's physical extent (filled with 0/nodata). If we pass |
| 119 | + # this boundless crop directly to rasterio.merge.merge, merge will treat |
| 120 | + # those extended bounds as valid source bounds and can shift boundary |
| 121 | + # behavior (notably for sub-pixel-aligned query grids like NAIP). |
| 122 | + # |
| 123 | + # We therefore slice away any rows/cols that are fully outside the physical |
| 124 | + # extent so the in-memory dataset's bounds match the real raster's bounds. |
| 125 | + a = float(src_crop_transform.a) |
| 126 | + e = float(src_crop_transform.e) |
| 127 | + if a <= 0.0 or e >= 0.0: |
| 128 | + raise ValueError( |
| 129 | + "Expected a north-up transform for merge semantics " f"(a={a}, e={e})." |
| 130 | + ) |
| 131 | + |
| 132 | + crop_w = int(src_crop.shape[1]) |
| 133 | + crop_h = int(src_crop.shape[0]) |
| 134 | + crop_left = float(src_crop_transform.c) |
| 135 | + crop_top = float(src_crop_transform.f) |
| 136 | + crop_right = crop_left + crop_w * a |
| 137 | + crop_bottom = crop_top + crop_h * e |
| 138 | + |
| 139 | + full_left = float(src_full_transform.c) |
| 140 | + full_top = float(src_full_transform.f) |
| 141 | + full_right = full_left + int(src_full_width) * float(src_full_transform.a) |
| 142 | + full_bottom = full_top + int(src_full_height) * float(src_full_transform.e) |
| 143 | + |
| 144 | + inter_left = max(crop_left, full_left) |
| 145 | + inter_right = min(crop_right, full_right) |
| 146 | + inter_top = min(crop_top, full_top) |
| 147 | + inter_bottom = max(crop_bottom, full_bottom) |
| 148 | + |
| 149 | + if inter_left < inter_right and inter_bottom < inter_top: |
| 150 | + import numpy as _np |
| 151 | + |
| 152 | + col0 = int(_np.ceil((inter_left - crop_left) / a - 1e-9)) |
| 153 | + col1 = int(_np.floor((inter_right - crop_left) / a + 1e-9)) |
| 154 | + row0 = int(_np.ceil((crop_top - inter_top) / abs(e) - 1e-9)) |
| 155 | + row1 = int(_np.floor((crop_top - inter_bottom) / abs(e) + 1e-9)) |
| 156 | + |
| 157 | + col0 = max(0, min(crop_w, col0)) |
| 158 | + col1 = max(col0, min(crop_w, col1)) |
| 159 | + row0 = max(0, min(crop_h, row0)) |
| 160 | + row1 = max(row0, min(crop_h, row1)) |
| 161 | + |
| 162 | + if row0 != 0 or col0 != 0 or row1 != crop_h or col1 != crop_w: |
| 163 | + src_crop = src_crop[row0:row1, col0:col1] |
| 164 | + src_crop_transform = Affine( |
| 165 | + a, |
| 166 | + 0.0, |
| 167 | + crop_left + col0 * a, |
| 168 | + 0.0, |
| 169 | + e, |
| 170 | + crop_top + row0 * e, |
| 171 | + ) |
| 172 | + |
| 173 | + profile = { |
| 174 | + "driver": "GTiff", |
| 175 | + "height": int(src_crop.shape[0]), |
| 176 | + "width": int(src_crop.shape[1]), |
| 177 | + "count": 1, |
| 178 | + "dtype": src_crop.dtype, |
| 179 | + "crs": RioCRS.from_epsg(int(src_crs)), |
| 180 | + "transform": src_crop_transform, |
| 181 | + **({"nodata": src_nodata} if src_nodata is not None else {}), |
| 182 | + } |
| 183 | + |
| 184 | + with MemoryFile() as mem: |
| 185 | + with mem.open(**profile) as ds: |
| 186 | + ds.write(src_crop, 1) |
| 187 | + |
| 188 | + with mem.open() as ds: |
| 189 | + data, _out_transform = rio_merge( |
| 190 | + [ds], |
| 191 | + bounds=grid.bounds, |
| 192 | + res=grid.res, |
| 193 | + indexes=[1], |
| 194 | + resampling=getattr(Resampling, resampling), |
| 195 | + ) |
| 196 | + return data.squeeze() |
0 commit comments