diff --git a/design_notes/raster_index.md b/design_notes/raster_index.md new file mode 100644 index 0000000..3a0ac1c --- /dev/null +++ b/design_notes/raster_index.md @@ -0,0 +1,137 @@ +# Design for a RasterIndex + +## TL;DR + +1. We propose designing a RasterIndex that can handle many ways of expressing a raster → model space transformation. +2. We propose that this RasterIndex _record_ the information for this transformation (e.g. `GeoTransform`) *internally* and remove that information from the dataset when constructed. +3. Since the information is recorded internally, a user must intentionally write the transformation back to the dataset, destroying the index while doing so, and then write to disk. + +## Goals: + +### UX goals: + +1. Make it easy to read a GeoTIFF with CRS and raster -> model space transformation information into Xarray with appropriate indexes. There are at least two indexes: one associated with the CRS; and one with the transformation. +2. The raster ↔ model transformation information can be ambiguous, so an explicit API should be provided. + 1. [RPCs](http://geotiff.maptools.org/rpc_prop.html): + > The RPC model in a GeoTIFF file is supplementary to all other GeoTIFF tags and not directly related. That is, it is possible to have a conventional set of GeoTIFF tags (such as a tiepoint + pixel scale + projected coordinate system description) along with the RPCCoefficientTag. The RPCCoefficientTag is always describing a transformation to WGS84, regardless of what geographic coordinate system might be described in the coordinate system description tags of the GeoTIFF file. It is also possible to have only the RPCCoefficientTag tag and no other GeoTIFF tags. + 2. [GeoTransform is not in the GeoTIFF standard](https://docs.ogc.org/is/19-008r4/19-008r4.html). Instead, the standard uses ModelTiepointTag, ModelPixelScaleTag, and ModelTransformationTag. + 3. 
[GCPs are ambiguous](https://gdal.org/en/stable/user/raster_data_model.html#gcps): + > The GDAL data model does not imply a transformation mechanism that must be generated from the GCPs … this is left to the application. However 1st to 5th order polynomials are common. + 4. [And for extra fun](https://gdal.org/en/stable/user/raster_data_model.html#gcps): + > Normally a dataset will contain either an affine geotransform, GCPs or neither. It is uncommon to have both, and it is undefined which is authoritative. + +### Index design goals: + +1. The CRS Index allows us to assert CRS compliance during alignment. This is provided by XProj. +2. The transform index should allow us to: + 1. Do accurate alignment in pixel space, unaffected by floating-point inaccuracies in model space; + 2. Keep **offsets** up to date during slicing, since all possible transforms have them; + 3. Extract the metadata necessary to accurately represent the information on disk. + +## Some complications + +### Handling projections +1. There are at least 5 ways to record a raster ⇒ model space transformation. +2. Information for these transforms may be stored in many places depending on the reading library: + 1. `ds.spatial_ref.attrs` (rioxarray stores GeoTransform, gcps here) + 2. `ds.band_data.attrs` (rioxarray stores TIEPOINTS here) + 3. `ds.attrs` possibly since this would be the most direct way to map TIFF tags + 4. It seems possible to store RPCs as either arrays or as attrs. + +We'd like a design that is extensible to handle all 5 (or more) cases; again suggesting an explicit lower-level API. + +### Composing with CRSIndex + +[One unanswered question so far is](https://github.com/benbovy/xproj/issues/22#issuecomment-2789459387) +> I think the major decision is "who handles the spatial_ref / grid mapping variable". Should it be exclusively handled by xproj.CRSIndex? 
Or should it be bound to a geospatial index such as rasterix.RasterIndex, xvec.GeometryIndex, xoak.S2PointIndex, etc.? + +i.e. does the Index wrap CRSIndex too, or compose with it? +Some points: +1. reprojection requires handling both the transform and the CRS. +2. the EDR-like selection API is similar. + +Importantly, GDAL chooses to write GeoTransform to the `grid_mapping` variable in netCDF which _also_ records CRS information. + +This gives us two options: +1. RasterIndex wraps CRSIndex too and handles everything. +2. RasterIndex extracts projection information, however it is stored (e.g. GeoTransform), and tracks it internally. Any functionality that relies on both the transformation and CRS will need to be built as an accessor layer. + +Below is a proposal for (2). + +## Proposal for transform index + +We design RasterIndex as a wrapper around **one** of many transform-based indexes: +1. AffineTransformIndex ↔ GeoTransform +2. ModelTransformationIndex ↔ ModelTransformationTag +3. ModelTiepointScaleIndex ↔ ModelTiepointTag + ModelPixelScaleTag +4. GCPIndex ↔ Ground Control Points +5. RPCIndex ↔ Rational Polynomial Coefficients +6. Subsampled auxiliary coordinates, detailed in [CF section 8.3](https://cfconventions.org/Data/cf-conventions/cf-conventions-1.12/cf-conventions.html#compression-by-coordinate-subsampling) and equivalent to GDAL's [geolocation arrays](https://gdal.org/en/stable/development/rfc/rfc4_geolocate.html) with `PIXEL_STEP` and/or `LINE_STEP` > 1. + +Each wrapped index has an associated transform: +```python +@dataclass +class RasterTransform: + rpc: RPC | None # rpcs + tiepoint_scale : ModelTiepointAndScale | None # ModelTiepointTag, ModelPixelScaleTag + gcps: GroundControlPoints | None + transformation : ModelTransformation | None + geotransform : Affine | None # GeoTransform + + def from_geotransform(attrs: dict) -> Self: + ... + + def from_tiepoints(attrs: dict) -> Self: + ... + + def from_gcps(gcps: ?) -> Self: + ... 
+ + def from_rpcs(?) -> Self: + ... + + def from_geolocation_arrays(?) -> Self: + ... +``` + +### Read-time +These transforms are constructed by **popping** the relevant information from a user-provided source. +This is analogous to an "encode/decode" workflow we currently have in Xarray. +```python +transform = rasterix.RasterTransform.from_geotransform(ds.spatial_ref.attrs) +# transform = rasterix.RasterTransform.from_tiepoints(ds.band_data.attrs) +``` + +By **popping** the information, the transformation is stored _only_ on the index, and must be rewritten back to the dataset by the user when writing to the disk. This is the RioXarray pattern. + +Once a transform is constructed, we could do +```python + index = RasterIndex.from_transform(transform, dims_and_sizes=...) + ds = ds.assign_coords(xr.Coordinates.from_xindex(index)) + ds = ds.set_xindex("spatial_ref", xproj.CRSIndex) +``` +### Write-time + +Before write, we must write the transform (similar to rioxarray) and _destroy_ the RasterIndex instance. +This seems almost required since there are many ways of recording the transformation information on the dataset; and many of these approaches might be used in the same dataset. +```python + # encodes the internal RasterTransform in attributes of the `to` variable + # destroys the RasterIndex + ds = ds.rasterix.write_transform(to: Hashable | rasterix.SELF, formats=["geotransform", "tiepoint"]) + ds.rio.to_raster() +``` +Here: +1. `SELF` could mean write to the object (Dataset | DataArray) attrs +2. `formats` allow you to record the same information in multiple ways +3. `.rio.write_transform` could just dispatch to this method. + + +## Appendix + +### Encode/decode workflow for subsampled coordinates + +Taking the example 8.3 from [CF section 8.3](https://cfconventions.org/Data/cf-conventions/cf-conventions-1.12/cf-conventions.html#compression-by-coordinate-subsampling), the decode step may consist in: +1. 
turn the tie point coordinate variables `lat(tp_yc, tp_xc)` and `lon(tp_yc, tp_xc)` into Xarray coordinates `lat(yc, xc)` and `lon(yc, xc)` associated with a custom transformation index that stores only the tie points. In other words, uncompress the dimensions of the `lat` & `lon` coordinates without uncompressing their data. +2. also remove the tie point index variables and the interpolation variable, and track their data / metadata internally in the index + +The encode step would then consist in restoring the compressed tie point coordinate & index variables as well as the interpolation variable. diff --git a/notebooks/examples.ipynb b/notebooks/examples.ipynb new file mode 100644 index 0000000..bbb208b --- /dev/null +++ b/notebooks/examples.ipynb @@ -0,0 +1,1046 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ddbd1512-2c4e-456c-b1a3-2441e9461da4", + "metadata": {}, + "source": [ + "# Raster dataset examples\n", + "\n", + "Written by ChatGPT.\n", + "\n", + "## GeoTiffs with GCPs and Tiepoints" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c6db7886-52e0-4fb3-90ec-a4e58f27909a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GeoTIFF with GCPs and Tie Points saved as output_with_gcps_tiepoints.tif\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/deepak/miniforge3/envs/raster/lib/python3.12/site-packages/rasterio/__init__.py:366: NotGeoreferencedWarning: The given matrix is equal to Affine.identity or its flipped counterpart. GDAL may ignore this matrix and save no geotransform without raising an error. 
This behavior is somewhat driver-specific.\n", + " dataset = writer(\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import rasterio\n", + "from rasterio.control import GroundControlPoint\n", + "from rasterio.transform import Affine\n", + "\n", + "# Define file path for output GeoTIFF\n", + "output_tiff = \"output_with_gcps_tiepoints.tif\"\n", + "\n", + "# Create a dummy raster (100x100) with random values\n", + "width, height = 100, 100\n", + "data = np.random.randint(0, 255, (1, height, width), dtype=np.uint8)\n", + "\n", + "# Define Affine transformation (default identity)\n", + "transform = Affine.translation(0, 0) * Affine.scale(1, -1)\n", + "\n", + "# Define Ground Control Points (GCPs)\n", + "gcps = [\n", + " GroundControlPoint(row=10, col=20, x=-122.123, y=37.456, z=50),\n", + " GroundControlPoint(row=30, col=40, x=-122.456, y=37.789, z=60),\n", + " GroundControlPoint(row=50, col=60, x=-122.789, y=38.123, z=55),\n", + "]\n", + "\n", + "# Define Tie Points (Model Tiepoints)\n", + "tie_points = [\n", + " (0, 0, 0, -122.1, 37.4, 0),\n", + " (50, 50, 0, -122.2, 37.5, 0),\n", + "]\n", + "\n", + "# Create a new GeoTIFF with GCPs\n", + "with rasterio.open(\n", + " output_tiff,\n", + " \"w\",\n", + " driver=\"GTiff\",\n", + " height=height,\n", + " width=width,\n", + " count=1,\n", + " dtype=rasterio.uint8,\n", + " crs=\"EPSG:4326\", # Define coordinate reference system (WGS84)\n", + " transform=transform,\n", + ") as dst:\n", + " dst.write(data)\n", + "\n", + " # Add Ground Control Points\n", + " dst.gcps = (gcps, dst.crs)\n", + "\n", + " # Add Tie Points as metadata\n", + " dst.update_tags(TIEPOINTS=str(tie_points))\n", + "\n", + "print(f\"GeoTIFF with GCPs and Tie Points saved as {output_tiff}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "58b89f55-ca15-4ac6-ac39-ccdf87bb3853", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 40kB\n",
+       "Dimensions:      (band: 1, y: 100, x: 100)\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B ...\n",
+       "Dimensions without coordinates: y, x\n",
+       "Data variables:\n",
+       "    band_data    (band, y, x) float32 40kB ...
" + ], + "text/plain": [ + " Size: 40kB\n", + "Dimensions: (band: 1, y: 100, x: 100)\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B ...\n", + "Dimensions without coordinates: y, x\n", + "Data variables:\n", + " band_data (band, y, x) float32 40kB ..." + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xarray as xr\n", + "\n", + "xr.open_dataset(\"output_with_gcps_tiepoints.tif\", engine=\"rasterio\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d50141f6-5eb2-49ca-819f-6d65adbf00c6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GeoTransform (Affine Transformation):\n", + "| 1.00, 0.00, 0.00|\n", + "| 0.00, 1.00, 0.00|\n", + "| 0.00, 0.00, 1.00|\n", + "([GroundControlPoint(row=10.0, col=20.0, x=-122.123, y=37.456, z=50.0, id='1', info=''), GroundControlPoint(row=30.0, col=40.0, x=-122.456, y=37.789, z=60.0, id='2', info=''), GroundControlPoint(row=50.0, col=60.0, x=-122.789, y=38.123, z=55.0, id='3', info='')], CRS.from_wkt('GEOGCS[\"WGS 84\",DATUM[\"WGS_1984\",SPHEROID[\"WGS 84\",6378137,298.257223563,AUTHORITY[\"EPSG\",\"7030\"]],AUTHORITY[\"EPSG\",\"6326\"]],PRIMEM[\"Greenwich\",0,AUTHORITY[\"EPSG\",\"8901\"]],UNIT[\"degree\",0.0174532925199433,AUTHORITY[\"EPSG\",\"9122\"]],AXIS[\"Latitude\",NORTH],AXIS[\"Longitude\",EAST],AUTHORITY[\"EPSG\",\"4326\"]]'))\n", + "{'TIEPOINTS': '[(0, 0, 0, -122.1, 37.4, 0), (50, 50, 0, -122.2, 37.5, 0)]', 'AREA_OR_POINT': 'Area'}\n" + ] + } + ], + "source": [ + "import rasterio\n", + "\n", + "# Open the saved GeoTIFF\n", + "output_tiff = \"output_with_gcps_tiepoints.tif\"\n", + "\n", + "with rasterio.open(output_tiff) as dataset:\n", + " print(\"GeoTransform (Affine Transformation):\")\n", + " print(dataset.transform)\n", + " print(dataset.gcps)\n", + " print(dataset.tags())" + ] + }, + { + "cell_type": "markdown", + "id": 
"db578a77-4290-4bf8-8840-69d980c3e0b3", + "metadata": {}, + "source": [ + "## GDAL netCDF with GeoTransform\n", + "\n", + "rasterio cannot write to netCDF" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4d63ac6e-4ee9-4223-b2cb-2dc9560ef157", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/deepak/miniforge3/envs/raster/lib/python3.12/site-packages/osgeo/osr.py:410: FutureWarning: Neither osr.UseExceptions() nor osr.DontUseExceptions() has been explicitly called. In GDAL 4.0, exceptions will be enabled by default.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NetCDF file 'output.nc' created with geotransform and projection.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from osgeo import gdal, osr\n", + "\n", + "# Define the output netCDF file\n", + "output_file = \"output.nc\"\n", + "\n", + "# Define raster dimensions\n", + "cols = 10 # Number of columns\n", + "rows = 10 # Number of rows\n", + "bands = 1 # Number of bands\n", + "\n", + "# Define geotransform (Origin X, Pixel Width, Rotation, Origin Y, Rotation, Pixel Height)\n", + "geotransform = [100.0, 30.0, 0.0, 200.0, 0.0, -30.0] # Example values\n", + "\n", + "# Define a spatial reference (WGS84 in this case)\n", + "srs = osr.SpatialReference()\n", + "srs.ImportFromEPSG(4326) # WGS84\n", + "proj_wkt = srs.ExportToWkt()\n", + "\n", + "# Create a new netCDF file using GDAL\n", + "driver = gdal.GetDriverByName(\"netCDF\")\n", + "dataset = driver.Create(output_file, cols, rows, bands, gdal.GDT_Float32)\n", + "\n", + "# Set geotransform and projection\n", + "dataset.SetGeoTransform(geotransform)\n", + "dataset.SetProjection(proj_wkt)\n", + "\n", + "# Create some dummy data (a simple gradient)\n", + "data = np.arange(cols * rows, dtype=np.float32).reshape(rows, cols)\n", + "\n", + "# Write data to the first band\n", + "band = dataset.GetRasterBand(1)\n", + 
"band.WriteArray(data)\n", + "band.SetNoDataValue(-9999) # Set a no-data value\n", + "\n", + "# Flush and close dataset\n", + "band.FlushCache()\n", + "dataset = None\n", + "\n", + "print(f\"NetCDF file '{output_file}' created with geotransform and projection.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1501beb4-19de-4abb-8953-e1d6dcf5b46d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 561B\n",
+       "Dimensions:  (lat: 10, lon: 10)\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float64 80B -85.0 -55.0 -25.0 5.0 ... 95.0 125.0 155.0 185.0\n",
+       "  * lon      (lon) float64 80B 115.0 145.0 175.0 205.0 ... 325.0 355.0 385.0\n",
+       "Data variables:\n",
+       "    Band1    (lat, lon) float32 400B ...\n",
+       "    crs      |S1 1B ...\n",
+       "Attributes:\n",
+       "    Conventions:  CF-1.5\n",
+       "    GDAL:         GDAL 3.10.1, released 2025/01/08\n",
+       "    history:      Wed Apr 09 20:27:59 2025: GDAL Create( output.nc, ... )
" + ], + "text/plain": [ + " Size: 561B\n", + "Dimensions: (lat: 10, lon: 10)\n", + "Coordinates:\n", + " * lat (lat) float64 80B -85.0 -55.0 -25.0 5.0 ... 95.0 125.0 155.0 185.0\n", + " * lon (lon) float64 80B 115.0 145.0 175.0 205.0 ... 325.0 355.0 385.0\n", + "Data variables:\n", + " Band1 (lat, lon) float32 400B ...\n", + " crs |S1 1B ...\n", + "Attributes:\n", + " Conventions: CF-1.5\n", + " GDAL: GDAL 3.10.1, released 2025/01/08\n", + " history: Wed Apr 09 20:27:59 2025: GDAL Create( output.nc, ... )" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xarray as xr\n", + "\n", + "xr.open_dataset(\"output.nc\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:raster]", + "language": "python", + "name": "conda-env-raster-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/test_raster_index.ipynb b/notebooks/test_raster_index.ipynb new file mode 100644 index 0000000..9932447 --- /dev/null +++ b/notebooks/test_raster_index.ipynb @@ -0,0 +1,3750 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fcc18135-13cb-4fd8-b04b-c5cc9836ee74", + "metadata": {}, + "source": [ + "# RioXarray raster index examples" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1b4ca507-2fa4-4809-a66e-28a907eda00e", + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import rioxarray as rio\n", + "\n", + "xr.set_options(display_expand_indexes=True);" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "efcd1d28-ed85-46a0-8729-f90bb2b63070", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: rasterix pyproject\n", + "import sys\n", + 
"sys.path.append(\"..\")\n", + "from rasterix import RasterIndex" + ] + }, + { + "cell_type": "markdown", + "id": "9cbecd44-d3a7-4efa-8ac1-9840788bed81", + "metadata": {}, + "source": [ + "## Example with rectilinear and no rotation affine transform\n", + "\n", + "Both x and y coordinates are 1-dimensional." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "01685f92-546b-454a-908f-12e64c1a4631", + "metadata": {}, + "outputs": [], + "source": [ + "source = \"/vsicurl/https://noaadata.apps.nsidc.org/NOAA/G02135/south/daily/geotiff/2024/01_Jan/S_20240101_concentration_v3.0.tif\"" + ] + }, + { + "cell_type": "markdown", + "id": "b12bfaff-fe20-43d1-b827-812a6299b87f", + "metadata": {}, + "source": [ + "### Load and inspect the datasets, with and without `RasterIndex`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "628796f4-0cd8-405d-99de-ff7a1de976ef", + "metadata": {}, + "outputs": [], + "source": [ + "da_no_raster_index = xr.open_dataarray(source, engine=\"rasterio\")\n", + "\n", + "\n", + "def set_raster_index(obj):\n", + " \"\"\"Return a new DataArray or Dataset with a RasterIndex.\"\"\"\n", + " x_dim = obj.rio.x_dim\n", + " y_dim = obj.rio.y_dim\n", + "\n", + " index = RasterIndex.from_transform(\n", + " obj.rio.transform(),\n", + " obj.sizes[x_dim],\n", + " obj.sizes[y_dim],\n", + " x_dim=x_dim,\n", + " y_dim=y_dim,\n", + " )\n", + "\n", + " # drop-in replacement of explicit x/y coordinates for now\n", + " coords = xr.Coordinates.from_xindex(index)\n", + " return obj.assign_coords(coords)\n", + "\n", + "\n", + "da_raster_index = set_raster_index(da_no_raster_index)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "68328dfd-2885-4bd9-8346-44c2f01d5deb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1, y: 332, x: 316)> Size: 420kB\n",
+       "[104912 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "  * x            (x) float64 3kB -3.938e+06 -3.912e+06 ... 3.912e+06 3.938e+06\n",
+       "  * y            (y) float64 3kB 4.338e+06 4.312e+06 ... -3.912e+06 -3.938e+06\n",
+       "    spatial_ref  int64 8B ...\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 420kB\n", + "[104912 values with dtype=float32]\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " * x (x) float64 3kB -3.938e+06 -3.912e+06 ... 3.912e+06 3.938e+06\n", + " * y (y) float64 3kB 4.338e+06 4.312e+06 ... -3.912e+06 -3.938e+06\n", + " spatial_ref int64 8B ...\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_no_raster_index" + ] + }, + { + "cell_type": "markdown", + "id": "28f1f83c-37e7-478f-b676-9800c7eae768", + "metadata": {}, + "source": [ + "The \"x\" and \"y\" coordinates with a raster index are lazy! The repr below shows a few values for each coordinate (those have been computed on-the-fly) but clicking on the database icon doesn't show any value in the spatial coordinate data reprs." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5f07a272-895f-40ac-8ae3-07a8d3e07521", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1, y: 332, x: 316)> Size: 420kB\n",
+       "[104912 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B ...\n",
+       "  * x            (x) float64 3kB -3.938e+06 -3.912e+06 ... 3.912e+06 3.938e+06\n",
+       "  * y            (y) float64 3kB 4.338e+06 4.312e+06 ... -3.912e+06 -3.938e+06\n",
+       "Indexes:\n",
+       "  ┌ x        RasterIndex\n",
+       "  └ y\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 420kB\n", + "[104912 values with dtype=float32]\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B ...\n", + " * x (x) float64 3kB -3.938e+06 -3.912e+06 ... 3.912e+06 3.938e+06\n", + " * y (y) float64 3kB 4.338e+06 4.312e+06 ... -3.912e+06 -3.938e+06\n", + "Indexes:\n", + " ┌ x RasterIndex\n", + " └ y\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_raster_index" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9ab08edf-c413-451b-9ddb-b9b894ea139c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "xarray.core.indexing.CoordinateTransformIndexingAdapter" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(da_raster_index.coords.variables[\"x\"]._data)" + ] + }, + { + "cell_type": "markdown", + "id": "81ccec5b-cfbc-4532-846d-725f62a9fb9b", + "metadata": {}, + "source": [ + "### Compare and align the datasets with and without `RasterIndex`\n", + "\n", + "`equals` compares variable values without relying on Xarray coordinate indexes. Both dataarrays should thus be equal." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "95309ec5-b317-4171-85ba-bd01cc9bd649", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_raster_index.equals(da_no_raster_index)" + ] + }, + { + "cell_type": "markdown", + "id": "faa19ff9-404f-4885-90a6-5188c6c2e585", + "metadata": {}, + "source": [ + "Xarray alignment relies on Xarray coordinate indexes. Trying to align both datasets fails here since they each have different index types.\n", + "\n", + "Maybe Xarray should try aligning the datasets based on coordinate variable data in this case? 
I don't think this would be easy to implement in a general way... Xarray's alignment logic is already complex! Also the alignment failure here is not necessarily a bad thing (cf. alignment issues with explicit floating-point coordinates)." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ac2d1532-2441-4a86-8339-a4d5ad78dd01", + "metadata": {}, + "outputs": [], + "source": [ + "# this fails!\n", + "\n", + "# da_raster_index + da_no_raster_index" + ] + }, + { + "cell_type": "markdown", + "id": "909e425e-178b-401a-b754-f198f59dce5e", + "metadata": {}, + "source": [ + "### Indexing the dataarray with `RasterIndex`" + ] + }, + { + "cell_type": "markdown", + "id": "e0bd8ecb-25c1-4ce9-9af8-b354b636d445", + "metadata": {}, + "source": [ + "#### Integer-based selection (isel)" + ] + }, + { + "cell_type": "markdown", + "id": "6832b3b2-c632-492b-ba51-582c010384eb", + "metadata": {}, + "source": [ + "- *Slicing both x and y*" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f2d90e00-72c5-4312-95e8-b9a23fe5d5e3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1, y: 166, x: 3)> Size: 2kB\n",
+       "array([[[0., 0., 0.],\n",
+       "        [0., 0., 0.],\n",
+       "        ...,\n",
+       "        [0., 0., 0.],\n",
+       "        [0., 0., 0.]]], shape=(1, 166, 3), dtype=float32)\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B 0\n",
+       "  * x            (x) float64 24B -3.912e+06 -3.888e+06 -3.862e+06\n",
+       "  * y            (y) float64 1kB 4.338e+06 4.288e+06 ... -3.862e+06 -3.912e+06\n",
+       "Indexes:\n",
+       "  ┌ x        RasterIndex\n",
+       "  └ y\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 2kB\n", + "array([[[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]]], shape=(1, 166, 3), dtype=float32)\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B 0\n", + " * x (x) float64 24B -3.912e+06 -3.888e+06 -3.862e+06\n", + " * y (y) float64 1kB 4.338e+06 4.288e+06 ... -3.862e+06 -3.912e+06\n", + "Indexes:\n", + " ┌ x RasterIndex\n", + " └ y\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_sliced = da_raster_index.isel(x=slice(1, 4), y=slice(None, None, 2))\n", + "da_sliced" + ] + }, + { + "cell_type": "markdown", + "id": "baf29b0e-3062-426d-920d-d0b293143422", + "metadata": {}, + "source": [ + "Slicing keeps both coordinates lazy (it computes a new affine transform):" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "513ebddc-bc55-4893-8f80-881e1bfa56c2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RasterIndex\n", + "'x':\n", + " \n", + "'y':\n", + " \n", + "RasterIndex\n", + "'x':\n", + " \n", + "'y':\n", + " \n" + ] + } + ], + "source": [ + "print(da_sliced.xindexes[\"x\"])\n", + "print(da_sliced.xindexes[\"y\"])" + ] + }, + { + "cell_type": "markdown", + "id": "dbafcfdc-4ffe-4c0c-985d-4f760a81b0bd", + "metadata": {}, + "source": [ + "- *Outer indexing with arbitrary array values*" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "be610609-4eed-4971-8fb3-bf1fc324d128", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1, y: 3, x: 3)> Size: 36B\n",
+       "array([[[0., 0., 0.],\n",
+       "        [0., 0., 0.],\n",
+       "        [0., 0., 0.]]], dtype=float32)\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B 0\n",
+       "  * x            (x) float64 24B -3.938e+06 -3.888e+06 -3.838e+06\n",
+       "  * y            (y) float64 24B 4.338e+06 4.338e+06 4.312e+06\n",
+       "Indexes:\n",
+       "  ┌ x        RasterIndex\n",
+       "  └ y\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 36B\n", + "array([[[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]]], dtype=float32)\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B 0\n", + " * x (x) float64 24B -3.938e+06 -3.888e+06 -3.838e+06\n", + " * y (y) float64 24B 4.338e+06 4.338e+06 4.312e+06\n", + "Indexes:\n", + " ┌ x RasterIndex\n", + " └ y\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_outer = da_raster_index.isel(x=[0, 2, 4], y=[0, 0, 1])\n", + "da_outer" + ] + }, + { + "cell_type": "markdown", + "id": "ce3502c0-3af5-4470-ba58-b46743ab3d7e", + "metadata": {}, + "source": [ + "We cannot compute a new affine transfrom given arbitrary array positions. To allow further data selection, pandas indexes are created for indexed spatial coordinates:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9695d517-99f8-4708-85a0-0aedc7a0386b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RasterIndex\n", + "'x':\n", + " PandasIndex(Index([-3937500.0, -3887500.0, -3837500.0], dtype='float64', name='x'))\n", + "'y':\n", + " PandasIndex(Index([4337500.0, 4337500.0, 4312500.0], dtype='float64', name='y'))\n", + "RasterIndex\n", + "'x':\n", + " PandasIndex(Index([-3937500.0, -3887500.0, -3837500.0], dtype='float64', name='x'))\n", + "'y':\n", + " PandasIndex(Index([4337500.0, 4337500.0, 4312500.0], dtype='float64', name='y'))\n" + ] + } + ], + "source": [ + "print(da_outer.xindexes[\"x\"])\n", + "print(da_outer.xindexes[\"y\"])" + ] + }, + { + "cell_type": "markdown", + "id": "a1b7a75b-94dc-4fad-b901-a7d2e563c5ff", + "metadata": {}, + "source": [ + "- *Basic indexing with scalars*" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "69a94f72-781d-412d-b02f-8068bef2e0a8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1)> Size: 4B\n",
+       "array([0.], dtype=float32)\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B 0\n",
+       "    x            float64 8B -3.938e+06\n",
+       "    y            float64 8B 4.312e+06\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 4B\n", + "array([0.], dtype=float32)\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B 0\n", + " x float64 8B -3.938e+06\n", + " y float64 8B 4.312e+06\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_scalar = da_raster_index.isel(x=0, y=1)\n", + "da_scalar" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "55148b6b-6f4c-4d2b-aded-1c7895a65bb5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1, y: 332)> Size: 1kB\n",
+       "array([[0., 0., 0., ..., 0., 0., 0.]], shape=(1, 332), dtype=float32)\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B 0\n",
+       "    x            float64 8B -3.938e+06\n",
+       "    y            (y) float64 3kB 4.338e+06 4.312e+06 ... -3.912e+06 -3.938e+06\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 1kB\n", + "array([[0., 0., 0., ..., 0., 0., 0.]], shape=(1, 332), dtype=float32)\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B 0\n", + " x float64 8B -3.938e+06\n", + " y (y) float64 3kB 4.338e+06 4.312e+06 ... -3.912e+06 -3.938e+06\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_xscalar = da_raster_index.isel(x=0)\n", + "da_xscalar" + ] + }, + { + "cell_type": "markdown", + "id": "859268ae-cb9f-49bb-ab4d-a716af7e7f1f", + "metadata": {}, + "source": [ + "**FIXME** The RasterIndex should be preserved in case of partial dimension reduction." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "8ad95332-8643-462d-a395-c34deaa06f28", + "metadata": {}, + "outputs": [], + "source": [ + "# da_xscalar.xindexes[\"y\"] # should return an index" + ] + }, + { + "cell_type": "markdown", + "id": "1402935c-3e70-4ec3-a2e2-8912158a0b11", + "metadata": {}, + "source": [ + "- *Vectorized (fancy) indexing*\n", + "\n", + "Indexing the spatial coordinates with Xarray `Variable` objects returns a `RasterIndex` (wrapping `PandasIndex`) for 1-dimensional variables and no index for scalar or n-dimensional variables." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "8e707c5b-6f7b-4027-be1b-683b10d24179", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1, z: 2)> Size: 8B\n",
+       "array([[0., 0.]], dtype=float32)\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B 0\n",
+       "  * x            (z) float64 16B -3.938e+06 -3.912e+06\n",
+       "  * y            (z) float64 16B 4.312e+06 4.312e+06\n",
+       "Dimensions without coordinates: z\n",
+       "Indexes:\n",
+       "  ┌ x        RasterIndex\n",
+       "  └ y\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 8B\n", + "array([[0., 0.]], dtype=float32)\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B 0\n", + " * x (z) float64 16B -3.938e+06 -3.912e+06\n", + " * y (z) float64 16B 4.312e+06 4.312e+06\n", + "Dimensions without coordinates: z\n", + "Indexes:\n", + " ┌ x RasterIndex\n", + " └ y\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_points = da_raster_index.isel(x=xr.Variable(\"z\", [0, 1]), y=xr.Variable(\"z\", [1, 1]))\n", + "da_points" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "36cafddb-a167-4e7f-97d6-f4b3420f1597", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'band_data' (band: 1, u: 2, v: 2)> Size: 16B\n",
+       "array([[[0., 0.],\n",
+       "        [0., 0.]]], dtype=float32)\n",
+       "Coordinates:\n",
+       "  * band         (band) int64 8B 1\n",
+       "    spatial_ref  int64 8B 0\n",
+       "    x            (u, v) float64 32B -3.938e+06 -3.912e+06 -3.888e+06 -3.862e+06\n",
+       "    y            (u, v) float64 32B 4.312e+06 4.312e+06 4.288e+06 4.288e+06\n",
+       "Dimensions without coordinates: u, v\n",
+       "Attributes:\n",
+       "    AREA_OR_POINT:  Area
" + ], + "text/plain": [ + " Size: 16B\n", + "array([[[0., 0.],\n", + " [0., 0.]]], dtype=float32)\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " spatial_ref int64 8B 0\n", + " x (u, v) float64 32B -3.938e+06 -3.912e+06 -3.888e+06 -3.862e+06\n", + " y (u, v) float64 32B 4.312e+06 4.312e+06 4.288e+06 4.288e+06\n", + "Dimensions without coordinates: u, v\n", + "Attributes:\n", + " AREA_OR_POINT: Area" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da_points2d = da_raster_index.isel(\n", + " x=xr.Variable((\"u\", \"v\"), [[0, 1], [2, 3]]),\n", + " y=xr.Variable((\"u\", \"v\"), [[1, 1], [2, 2]]),\n", + ")\n", + "da_points2d" + ] + }, + { + "cell_type": "markdown", + "id": "5e9c553c-a056-4d31-9102-51c44ac1cd37", + "metadata": {}, + "source": [ + "#### Label-based selection (sel)" + ] + }, + { + "cell_type": "markdown", + "id": "5ecd7d5a-a20e-4cfd-bbf2-87fac330bd08", + "metadata": {}, + "source": [ + "TODO" + ] + }, + { + "cell_type": "markdown", + "id": "5bdbb1ae-2715-47b7-9743-144fabd335d3", + "metadata": {}, + "source": [ + "## Example with complex affine transformation\n", + "\n", + "x and y coordinates are both 2-dimensional." 
+ ] + }, + { + "cell_type": "markdown", + "id": "fb02ebda-7a09-48b2-b95d-7e78a3fd147c", + "metadata": {}, + "source": [ + "TODO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b468192f-f199-4c0f-a4d3-0aa33e2ad873", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (Pixi)", + "language": "python", + "name": "pixi-kernel-python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/rasterix/__init__.py b/rasterix/__init__.py new file mode 100644 index 0000000..bb595b0 --- /dev/null +++ b/rasterix/__init__.py @@ -0,0 +1,4 @@ +from .raster_index import RasterIndex + + +__all__ = ["RasterIndex"] diff --git a/rasterix/raster_index.py b/rasterix/raster_index.py new file mode 100644 index 0000000..cbe9d88 --- /dev/null +++ b/rasterix/raster_index.py @@ -0,0 +1,435 @@ +import textwrap +from collections.abc import Hashable, Mapping +from typing import Any + +import numpy as np +import pandas as pd +from affine import Affine +from xarray import DataArray, Index, Variable +from xarray.core.coordinate_transform import CoordinateTransform + +# TODO: import from public API once it is available +from xarray.core.indexes import CoordinateTransformIndex, PandasIndex +from xarray.core.indexing import IndexSelResult, merge_sel_results + + +class AffineTransform(CoordinateTransform): + """Affine 2D transform wrapper.""" + + affine: Affine + xy_dims: tuple[str, str] + + def __init__( + self, + affine: Affine, + width: int, + height: int, + x_coord_name: Hashable = "xc", + y_coord_name: Hashable = "yc", + x_dim: str = "x", + y_dim: str = "y", + dtype: Any = np.dtype(np.float64), + ): + super().__init__( + (x_coord_name, 
y_coord_name), {x_dim: width, y_dim: height}, dtype=dtype + ) + self.affine = affine + + # array dimensions in reverse order (y = rows, x = cols) + self.xy_dims = self.dims[0], self.dims[1] + self.dims = self.dims[1], self.dims[0] + + def forward(self, dim_positions): + positions = tuple(dim_positions[dim] for dim in self.xy_dims) + x_labels, y_labels = self.affine * positions + + results = {} + for name, labels in zip(self.coord_names, [x_labels, y_labels]): + results[name] = labels + + return results + + def reverse(self, coord_labels): + labels = tuple(coord_labels[name] for name in self.coord_names) + x_positions, y_positions = ~self.affine * labels + + results = {} + for dim, positions in zip(self.xy_dims, [x_positions, y_positions]): + results[dim] = positions + + return results + + def equals(self, other): + if not isinstance(other, AffineTransform): + return False + return self.affine == other.affine and self.dim_size == other.dim_size + + +class AxisAffineTransform(CoordinateTransform): + """Axis-independent wrapper of an affine 2D transform with no skew/rotation.""" + + affine: Affine + is_xaxis: bool + coord_name: Hashable + dim: str + size: int + + def __init__( + self, + affine: Affine, + size: int, + coord_name: Hashable, + dim: str, + is_xaxis: bool, + dtype: Any = np.dtype(np.float64), + ): + assert affine.is_rectilinear and (affine.b == affine.d == 0) + + super().__init__((coord_name,), {dim: size}, dtype=dtype) + self.affine = affine + self.is_xaxis = is_xaxis + self.coord_name = coord_name + self.dim = dim + self.size = size + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + positions = np.asarray(dim_positions[self.dim]) + + if self.is_xaxis: + labels, _ = self.affine * (positions, np.zeros_like(positions)) + else: + _, labels = self.affine * (np.zeros_like(positions), positions) + + return {self.coord_name: labels} + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + labels = 
np.asarray(coord_labels[self.coord_name]) + + if self.is_xaxis: + positions, _ = ~self.affine * (labels, np.zeros_like(labels)) + else: + _, positions = ~self.affine * (np.zeros_like(labels), labels) + + return {self.dim: positions} + + def equals(self, other): + if not isinstance(other, AxisAffineTransform): + return False + + # only compare the affine parameters of the relevant axis + if self.is_xaxis: + affine_match = ( + self.affine.a == other.affine.a and self.affine.c == other.affine.c + ) + else: + affine_match = ( + self.affine.e == other.affine.e and self.affine.f == other.affine.f + ) + + return affine_match and self.size == other.size + + def generate_coords( + self, dims: tuple[str, ...] | None = None + ) -> dict[Hashable, Any]: + assert dims is None or dims == self.dims + return self.forward({self.dim: np.arange(self.size)}) + + def slice(self, slice: slice) -> "AxisAffineTransform": + start = max(slice.start or 0, 0) + stop = min(slice.stop or self.size, self.size) + step = slice.step or 1 + + # TODO: support reverse transform (i.e., start > stop)? + assert start < stop + + # ceiling division: number of positions in range(start, stop, step) + size = (stop - start + step - 1) // step + scale = float(step) + + if self.is_xaxis: + affine = ( + self.affine * Affine.translation(start, 0.0) * Affine.scale(scale, 1.0) + ) + else: + affine = ( + self.affine * Affine.translation(0.0, start) * Affine.scale(1.0, scale) + ) + + return type(self)( + affine, + size, + self.coord_name, + self.dim, + is_xaxis=self.is_xaxis, + dtype=self.dtype, + ) + + +class AxisAffineTransformIndex(CoordinateTransformIndex): + """Axis-independent Xarray Index for an affine 2D transform with no + skew/rotation. + + For internal use only.
+ + This Index class provides specific behavior on top of + Xarray's `CoordinateTransformIndex`: + + - Data slicing computes a new affine transform and returns a new + `AxisAffineTransformIndex` object + + - Otherwise data selection creates and returns a new Xarray + `PandasIndex` object for non-scalar indexers + + - The index can be converted to a `pandas.Index` object (useful for Xarray + operations that don't work with Xarray indexes yet). + + """ + + axis_transform: AxisAffineTransform + dim: str + + def __init__(self, transform: AxisAffineTransform): + assert isinstance(transform, AxisAffineTransform) + super().__init__(transform) + self.axis_transform = transform + self.dim = transform.dim + + def isel( # type: ignore[override] + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> "AxisAffineTransformIndex | PandasIndex | None": + idxer = indexers[self.dim] + + # generate a new index with updated transform if a slice is given + if isinstance(idxer, slice): + return AxisAffineTransformIndex(self.axis_transform.slice(idxer)) + # no index for vectorized (fancy) indexing with n-dimensional Variable + elif isinstance(idxer, Variable) and idxer.ndim > 1: + return None + # no index for scalar value + elif np.ndim(idxer) == 0: + return None + # otherwise return a PandasIndex with values computed by forward transformation + else: + values = self.axis_transform.forward({self.dim: idxer})[ + self.axis_transform.coord_name + ] + if isinstance(idxer, Variable): + new_dim = idxer.dims[0] + else: + new_dim = self.dim + return PandasIndex(values, new_dim, coord_dtype=values.dtype) + + def sel(self, labels, method=None, tolerance=None): + coord_name = self.axis_transform.coord_name + label = labels[coord_name] + + if isinstance(label, slice): + if label.step is None: + # continuous interval slice indexing (preserves the index) + pos = self.transform.reverse( + {coord_name: np.array([label.start, label.stop])} + ) + pos = 
np.round(pos[self.dim]).astype("int") + new_start = max(pos[0], 0) + new_stop = min(pos[1], self.axis_transform.size) + return IndexSelResult({self.dim: slice(new_start, new_stop)}) + else: + # otherwise convert to basic (array) indexing + label = np.arange(label.start, label.stop, label.step) + + # support basic indexing (in the 1D case basic vs. vectorized indexing + # are pretty much similar) + unwrap_xr = False + if not isinstance(label, Variable | DataArray): + # basic indexing -> either scalar or 1-d array + try: + var = Variable("_", label) + except ValueError: + var = Variable((), label) + labels = {self.dim: var} + unwrap_xr = True + + result = super().sel(labels, method=method, tolerance=tolerance) + + if unwrap_xr: + dim_indexers = {self.dim: result.dim_indexers[self.dim].values} + result = IndexSelResult(dim_indexers) + + return result + + def to_pandas_index(self) -> pd.Index: + values = self.transform.generate_coords() + return pd.Index(values[self.dim]) + + +# The types of Xarray indexes that may be wrapped by RasterIndex +WrappedIndex = AxisAffineTransformIndex | PandasIndex | CoordinateTransformIndex +WrappedIndexCoords = Hashable | tuple[Hashable, Hashable] + + +def _filter_dim_indexers(index: WrappedIndex, indexers: Mapping) -> Mapping: + if isinstance(index, CoordinateTransformIndex): + dims = index.transform.dims + else: + # PandasIndex + dims = (str(index.dim),) + + return {dim: indexers[dim] for dim in dims if dim in indexers} + + +class RasterIndex(Index): + """Xarray index for raster coordinates. + + RasterIndex is itself a wrapper around one or more Xarray indexes associated + with either the raster x or y axis coordinate or both, depending on the + affine transformation and prior data selection (if any): + + - The affine transformation is not rectilinear or has rotation: this index + encapsulates a single `CoordinateTransformIndex` object for both the x and + y axis (2-dimensional) coordinates. 
+ + - The affine transformation is rectilinear and has no rotation: this index + encapsulates one or two index objects for either the x or y axis or both + (1-dimensional) coordinates. The index type is either a subclass of + `CoordinateTransformIndex` that supports slicing or `PandasIndex` (e.g., + after data selection at arbitrary locations). + + """ + + _wrapped_indexes: dict[WrappedIndexCoords, WrappedIndex] + + def __init__(self, indexes: Mapping[WrappedIndexCoords, WrappedIndex]): + idx_keys = list(indexes) + idx_vals = list(indexes.values()) + + # either one or the other configuration (dependent vs. independent x/y axes) + axis_dependent = ( + len(indexes) == 1 + and isinstance(idx_keys[0], tuple) + and isinstance(idx_vals[0], CoordinateTransformIndex) + ) + axis_independent = len(indexes) in (1, 2) and all( + isinstance(idx, AxisAffineTransformIndex | PandasIndex) for idx in idx_vals + ) + assert axis_dependent ^ axis_independent + + self._wrapped_indexes = dict(indexes) + + @classmethod + def from_transform( + cls, affine: Affine, width: int, height: int, x_dim: str = "x", y_dim: str = "y" + ) -> "RasterIndex": + indexes: dict[ + WrappedIndexCoords, AxisAffineTransformIndex | CoordinateTransformIndex + ] + + # pixel centered coordinates + affine = affine * Affine.translation(0.5, 0.5) + + if affine.is_rectilinear and affine.b == affine.d == 0: + x_transform = AxisAffineTransform(affine, width, "x", x_dim, is_xaxis=True) + y_transform = AxisAffineTransform( + affine, height, "y", y_dim, is_xaxis=False + ) + indexes = { + "x": AxisAffineTransformIndex(x_transform), + "y": AxisAffineTransformIndex(y_transform), + } + else: + xy_transform = AffineTransform( + affine, width, height, x_dim=x_dim, y_dim=y_dim + ) + indexes = {("x", "y"): CoordinateTransformIndex(xy_transform)} + + return cls(indexes) + + @classmethod + def from_variables( + cls, + variables: Mapping[Any, Variable], + *, + options: Mapping[str, Any], + ) -> "RasterIndex": + # TODO: compute
bounds, resolution and affine transform from explicit coordinates. + raise NotImplementedError( + "Creating a RasterIndex from existing coordinates is not yet supported." + ) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> dict[Hashable, Variable]: + new_variables: dict[Hashable, Variable] = {} + + for index in self._wrapped_indexes.values(): + new_variables.update(index.create_variables()) + + return new_variables + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> "RasterIndex | None": + new_indexes: dict[WrappedIndexCoords, WrappedIndex] = {} + + for coord_names, index in self._wrapped_indexes.items(): + index_indexers = _filter_dim_indexers(index, indexers) + if not index_indexers: + # no selection to perform: simply propagate the index + # TODO: uncomment when https://github.com/pydata/xarray/issues/10063 is fixed + # new_indexes[coord_names] = index + ... + else: + new_index = index.isel(index_indexers) + if new_index is not None: + new_indexes[coord_names] = new_index + + if new_indexes: + # TODO: if there's only a single PandasIndex can we just return it? 
+ # (maybe better to keep it wrapped if we plan to later make RasterIndex CRS-aware) + return RasterIndex(new_indexes) + else: + return None + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + results = [] + + for coord_names, index in self._wrapped_indexes.items(): + if not isinstance(coord_names, tuple): + coord_names = (coord_names,) + index_labels = {k: v for k, v in labels.items() if k in coord_names} + if index_labels: + results.append( + index.sel(index_labels, method=method, tolerance=tolerance) + ) + + return merge_sel_results(results) + + def equals(self, other: Index) -> bool: + if not isinstance(other, RasterIndex): + return False + if set(self._wrapped_indexes) != set(other._wrapped_indexes): + return False + + return all( + index.equals(other._wrapped_indexes[k]) # type: ignore[arg-type] + for k, index in self._wrapped_indexes.items() + ) + + def to_pandas_index(self) -> pd.Index: + # conversion is possible only if this raster index encapsulates + # exactly one AxisAffineTransformIndex or a PandasIndex associated + # to either the x or y axis (1-dimensional) coordinate. + if len(self._wrapped_indexes) == 1: + index = next(iter(self._wrapped_indexes.values())) + if isinstance(index, AxisAffineTransformIndex | PandasIndex): + return index.to_pandas_index() + + raise ValueError("Cannot convert RasterIndex to pandas.Index") + + def __repr__(self): + items: list[str] = [] + + for coord_names, index in self._wrapped_indexes.items(): + items += [repr(coord_names) + ":", textwrap.indent(repr(index), " ")] + + return "RasterIndex\n" + "\n".join(items)