diff --git a/pyproject.toml b/pyproject.toml
index 319914c..1d5e371 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dynamic = ["version", "description"]
 requires-python = ">=3.8"
 dependencies = [
     "ciso8601",
+    "deltalake",
     "geopandas",
     "packaging",
     "pandas",
diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py
index 2fee475..0cb25e8 100644
--- a/stac_geoparquet/arrow/_api.py
+++ b/stac_geoparquet/arrow/_api.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import os
 from pathlib import Path
-from typing import Any, Dict, Iterable, Iterator, Optional, Union
+from typing import Any, Iterable, Iterator
 
 import pyarrow as pa
 
@@ -11,10 +13,10 @@
 def parse_stac_items_to_arrow(
-    items: Iterable[Dict[str, Any]],
+    items: Iterable[dict[str, Any]],
     *,
     chunk_size: int = 8192,
-    schema: Optional[Union[pa.Schema, InferredSchema]] = None,
+    schema: pa.Schema | InferredSchema | None = None,
 ) -> Iterable[pa.RecordBatch]:
     """Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`.
@@ -51,11 +53,11 @@ def parse_stac_items_to_arrow(
 def parse_stac_ndjson_to_arrow(
-    path: Union[str, Path, Iterable[Union[str, Path]]],
+    path: str | Path | Iterable[str | Path],
     *,
     chunk_size: int = 65536,
-    schema: Optional[pa.Schema] = None,
-    limit: Optional[int] = None,
+    schema: pa.Schema | None = None,
+    limit: int | None = None,
 ) -> Iterator[pa.RecordBatch]:
     """
     Convert one or more newline-delimited JSON STAC files to a generator of Arrow
@@ -83,6 +85,7 @@
     if schema is None:
         inferred_schema = InferredSchema()
         inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit)
+        inferred_schema.manual_updates()
         yield from parse_stac_ndjson_to_arrow(
             path, chunk_size=chunk_size, schema=inferred_schema
         )
@@ -103,7 +106,7 @@ def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
 def stac_table_to_ndjson(
-    table: pa.Table, dest: Union[str, Path, os.PathLike[bytes]]
+    table: pa.Table, dest: str | Path | os.PathLike[bytes]
 ) -> None:
     """Write a STAC Table to a newline-delimited JSON file."""
     for batch in table.to_batches():
@@ -112,7 +115,7 @@
 def stac_items_to_arrow(
-    items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None
+    items: Iterable[dict[str, Any]], *, schema: pa.Schema | None = None
 ) -> pa.RecordBatch:
     """Convert dicts representing STAC Items to Arrow
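A minimal usage sketch of the schema-inference path touched above (the input filename is a placeholder): when schema is None, parse_stac_ndjson_to_arrow scans the file once to infer a schema, applies the new manual_updates() pass, then re-reads the file to yield batches.

import pyarrow as pa

from stac_geoparquet.arrow._api import parse_stac_ndjson_to_arrow

# "items.ndjson" is a hypothetical newline-delimited STAC file.
batches = parse_stac_ndjson_to_arrow("items.ndjson", chunk_size=65536)
# Collect the generator into a single Table (assumes at least one batch was produced).
table = pa.Table.from_batches(list(batches))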
diff --git a/stac_geoparquet/arrow/_delta_lake.py b/stac_geoparquet/arrow/_delta_lake.py
new file mode 100644
index 0000000..d45c4bc
--- /dev/null
+++ b/stac_geoparquet/arrow/_delta_lake.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import itertools
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Iterable
+
+import pyarrow as pa
+from deltalake import write_deltalake
+
+from stac_geoparquet.arrow._api import parse_stac_ndjson_to_arrow
+from stac_geoparquet.arrow._to_parquet import create_geoparquet_metadata
+
+if TYPE_CHECKING:
+    from deltalake import DeltaTable
+
+
+def parse_stac_ndjson_to_delta_lake(
+    input_path: str | Path | Iterable[str | Path],
+    table_or_uri: str | Path | DeltaTable,
+    *,
+    chunk_size: int = 65536,
+    schema: pa.Schema | None = None,
+    limit: int | None = None,
+    **kwargs: Any,
+) -> None:
+    batches_iter = parse_stac_ndjson_to_arrow(
+        input_path, chunk_size=chunk_size, schema=schema, limit=limit
+    )
+    first_batch = next(batches_iter)
+    schema = first_batch.schema.with_metadata(
+        create_geoparquet_metadata(pa.Table.from_batches([first_batch]))
+    )
+    combined_iter = itertools.chain([first_batch], batches_iter)
+    write_deltalake(table_or_uri, combined_iter, schema=schema, engine="rust", **kwargs)
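A call sketch for the new Delta Lake writer (both paths are placeholders; any extra keyword arguments are forwarded to deltalake.write_deltalake):

from deltalake import DeltaTable

from stac_geoparquet.arrow._delta_lake import parse_stac_ndjson_to_delta_lake

# Write a hypothetical NDJSON file to a new Delta Lake table.
parse_stac_ndjson_to_delta_lake("items.ndjson", "out/stac-table", limit=1000)

# Read it back as an Arrow table.
table = DeltaTable("out/stac-table").to_pyarrow_table()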
diff --git a/stac_geoparquet/arrow/_schema/models.py b/stac_geoparquet/arrow/_schema/models.py
index 17fd169..4043ada 100644
--- a/stac_geoparquet/arrow/_schema/models.py
+++ b/stac_geoparquet/arrow/_schema/models.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from pathlib import Path
-from typing import Any, Dict, Iterable, Optional, Sequence, Union
+from typing import Any, Iterable, Sequence
 
 import pyarrow as pa
 
@@ -27,10 +29,10 @@ def __init__(self) -> None:
     def update_from_json(
         self,
-        path: Union[str, Path, Iterable[Union[str, Path]]],
+        path: str | Path | Iterable[str | Path],
         *,
         chunk_size: int = 65536,
-        limit: Optional[int] = None,
+        limit: int | None = None,
     ) -> None:
         """
         Update this inferred schema from one or more newline-delimited JSON STAC files.
@@ -45,7 +47,7 @@
         for batch in read_json_chunked(path, chunk_size=chunk_size, limit=limit):
             self.update_from_items(batch)
 
-    def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None:
+    def update_from_items(self, items: Sequence[dict[str, Any]]) -> None:
         """Update this inferred schema from a sequence of STAC Items."""
         self.count += len(items)
         current_schema = StacJsonBatch.from_dicts(items, schema=None).inner.schema
@@ -53,3 +55,49 @@
             [self.inner, current_schema], promote_options="permissive"
         )
         self.inner = new_schema
+
+    def manual_updates(self) -> None:
+        schema = self.inner
+        properties_field = schema.field("properties")
+        properties_schema = pa.schema(properties_field.type)
+
+        # The datetime column can be inferred as `null` in the case of a Collection with
+        # start_datetime and end_datetime. But `null` is incompatible with Delta Lake,
+        # so we coerce to a Timestamp type.
+        if pa.types.is_null(properties_schema.field("datetime").type):
+            field_idx = properties_schema.get_field_index("datetime")
+            properties_schema = properties_schema.set(
+                field_idx,
+                properties_schema.field(field_idx).with_type(
+                    pa.timestamp("us", tz="UTC")
+                ),
+            )
+
+        if "proj:epsg" in properties_schema.names and pa.types.is_null(
+            properties_schema.field("proj:epsg").type
+        ):
+            field_idx = properties_schema.get_field_index("proj:epsg")
+            properties_schema = properties_schema.set(
+                field_idx,
+                properties_schema.field(field_idx).with_type(pa.int64()),
+            )
+
+        if "proj:wkt2" in properties_schema.names and pa.types.is_null(
+            properties_schema.field("proj:wkt2").type
+        ):
+            field_idx = properties_schema.get_field_index("proj:wkt2")
+            properties_schema = properties_schema.set(
+                field_idx,
+                properties_schema.field(field_idx).with_type(pa.string()),
+            )
+
+        # Note: proj:projjson can also be null, but we don't have a type we can cast
+        # that to.
+
+        properties_idx = schema.get_field_index("properties")
+        updated_schema = schema.set(
+            properties_idx,
+            properties_field.with_type(pa.struct(properties_schema)),
+        )
+
+        self.inner = updated_schema
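A standalone sketch (toy schema, not library code) of the coercion pattern manual_updates relies on: a field whose type was inferred as null is swapped for a concrete type via Schema.set and Field.with_type.

import pyarrow as pa

properties = pa.schema(
    [pa.field("datetime", pa.null()), pa.field("proj:epsg", pa.null())]
)

# Replace the null-typed "datetime" field with a UTC timestamp field.
idx = properties.get_field_index("datetime")
properties = properties.set(
    idx, properties.field(idx).with_type(pa.timestamp("us", tz="UTC"))
)

# Replace the null-typed "proj:epsg" field with an int64 field.
idx = properties.get_field_index("proj:epsg")
properties = properties.set(idx, properties.field(idx).with_type(pa.int64()))

print(properties)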
""" metadata = table.schema.metadata or {} - metadata.update(_create_geoparquet_metadata(table)) + metadata.update(create_geoparquet_metadata(table)) table = table.replace_schema_metadata(metadata) pq.write_table(table, where, **kwargs) -def _create_geoparquet_metadata(table: pa.Table) -> dict[bytes, bytes]: +def create_geoparquet_metadata(table: pa.Table) -> dict[bytes, bytes]: # TODO: include bbox of geometries column_meta = { "encoding": "WKB", @@ -77,7 +80,7 @@ def _create_geoparquet_metadata(table: pa.Table) -> dict[bytes, bytes]: } }, } - geo_meta: Dict[str, Any] = { + geo_meta: dict[str, Any] = { "version": "1.1.0-dev", "columns": {"geometry": column_meta}, "primary_column": "geometry", diff --git a/stac_geoparquet/cli.py b/stac_geoparquet/cli.py index bbb2403..6e350e6 100644 --- a/stac_geoparquet/cli.py +++ b/stac_geoparquet/cli.py @@ -1,15 +1,16 @@ +from __future__ import annotations + import argparse import logging import sys import os -from typing import List, Optional from stac_geoparquet import pc_runner logger = logging.getLogger("stac_geoparquet.pgstac_reader") -def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: +def parse_args(args: list[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument( "--output-protocol", @@ -90,7 +91,7 @@ def setup_logging() -> None: } -def main(inp: Optional[List[str]] = None) -> int: +def main(inp: list[str] | None = None) -> int: import azure.data.tables args = parse_args(inp) diff --git a/stac_geoparquet/json_reader.py b/stac_geoparquet/json_reader.py index 62589d7..14c186c 100644 --- a/stac_geoparquet/json_reader.py +++ b/stac_geoparquet/json_reader.py @@ -1,7 +1,9 @@ """Return an iterator of items from an ndjson, a json array of items, or a featurecollection of items.""" +from __future__ import annotations + from pathlib import Path -from typing import Any, Dict, Iterable, Optional, Sequence, Union +from typing import Any, Iterable, Sequence import orjson @@ -9,8 +11,8 @@ def read_json( - path: Union[str, Path, Iterable[Union[str, Path]]], -) -> Iterable[Dict[str, Any]]: + path: str | Path | Iterable[str | Path], +) -> Iterable[dict[str, Any]]: """Read a json or ndjson file.""" if isinstance(path, (str, Path)): path = [path] @@ -39,10 +41,10 @@ def read_json( def read_json_chunked( - path: Union[str, Path, Iterable[Union[str, Path]]], + path: str | Path | Iterable[str | Path], chunk_size: int, *, - limit: Optional[int] = None, -) -> Iterable[Sequence[Dict[str, Any]]]: + limit: int | None = None, +) -> Iterable[Sequence[dict[str, Any]]]: """Read from a JSON or NDJSON file in chunks of `chunk_size`.""" return batched_iter(read_json(path), chunk_size, limit=limit) diff --git a/tests/test_delta_lake.py b/tests/test_delta_lake.py new file mode 100644 index 0000000..4bf0124 --- /dev/null +++ b/tests/test_delta_lake.py @@ -0,0 +1,46 @@ +import json +from pathlib import Path + +import pytest +from deltalake import DeltaTable + +from stac_geoparquet.arrow import stac_table_to_items +from stac_geoparquet.arrow._delta_lake import parse_stac_ndjson_to_delta_lake + +from .json_equals import assert_json_value_equal + +HERE = Path(__file__).parent + +TEST_COLLECTIONS = [ + "3dep-lidar-copc", + # "3dep-lidar-dsm", + "cop-dem-glo-30", + "io-lulc-annual-v02", + # "io-lulc", + "landsat-c2-l1", + "landsat-c2-l2", + "naip", + "planet-nicfi-analytic", + "sentinel-1-rtc", + "sentinel-2-l2a", + "us-census", +] + + +@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS) +def 
diff --git a/stac_geoparquet/cli.py b/stac_geoparquet/cli.py
index bbb2403..6e350e6 100644
--- a/stac_geoparquet/cli.py
+++ b/stac_geoparquet/cli.py
@@ -1,15 +1,16 @@
+from __future__ import annotations
+
 import argparse
 import logging
 import sys
 import os
-from typing import List, Optional
 
 from stac_geoparquet import pc_runner
 
 logger = logging.getLogger("stac_geoparquet.pgstac_reader")
 
 
-def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
+def parse_args(args: list[str] | None = None) -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--output-protocol",
@@ -90,7 +91,7 @@ def setup_logging() -> None:
 }
 
 
-def main(inp: Optional[List[str]] = None) -> int:
+def main(inp: list[str] | None = None) -> int:
     import azure.data.tables
 
     args = parse_args(inp)
diff --git a/stac_geoparquet/json_reader.py b/stac_geoparquet/json_reader.py
index 62589d7..14c186c 100644
--- a/stac_geoparquet/json_reader.py
+++ b/stac_geoparquet/json_reader.py
@@ -1,7 +1,9 @@
 """Return an iterator of items from an ndjson, a json array of items, or a featurecollection of items."""
+from __future__ import annotations
+
 from pathlib import Path
-from typing import Any, Dict, Iterable, Optional, Sequence, Union
+from typing import Any, Iterable, Sequence
 
 import orjson
 
@@ -9,8 +11,8 @@
 def read_json(
-    path: Union[str, Path, Iterable[Union[str, Path]]],
-) -> Iterable[Dict[str, Any]]:
+    path: str | Path | Iterable[str | Path],
+) -> Iterable[dict[str, Any]]:
     """Read a json or ndjson file."""
     if isinstance(path, (str, Path)):
         path = [path]
@@ -39,10 +41,10 @@ def read_json(
 def read_json_chunked(
-    path: Union[str, Path, Iterable[Union[str, Path]]],
+    path: str | Path | Iterable[str | Path],
     chunk_size: int,
     *,
-    limit: Optional[int] = None,
-) -> Iterable[Sequence[Dict[str, Any]]]:
+    limit: int | None = None,
+) -> Iterable[Sequence[dict[str, Any]]]:
     """Read from a JSON or NDJSON file in chunks of `chunk_size`."""
     return batched_iter(read_json(path), chunk_size, limit=limit)
diff --git a/tests/test_delta_lake.py b/tests/test_delta_lake.py
new file mode 100644
index 0000000..4bf0124
--- /dev/null
+++ b/tests/test_delta_lake.py
@@ -0,0 +1,46 @@
+import json
+from pathlib import Path
+
+import pytest
+from deltalake import DeltaTable
+
+from stac_geoparquet.arrow import stac_table_to_items
+from stac_geoparquet.arrow._delta_lake import parse_stac_ndjson_to_delta_lake
+
+from .json_equals import assert_json_value_equal
+
+HERE = Path(__file__).parent
+
+TEST_COLLECTIONS = [
+    "3dep-lidar-copc",
+    # "3dep-lidar-dsm",
+    "cop-dem-glo-30",
+    "io-lulc-annual-v02",
+    # "io-lulc",
+    "landsat-c2-l1",
+    "landsat-c2-l2",
+    "naip",
+    "planet-nicfi-analytic",
+    "sentinel-1-rtc",
+    "sentinel-2-l2a",
+    "us-census",
+]
+
+
+@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
+def test_round_trip_via_delta_lake(collection_id: str, tmp_path: Path):
+    path = HERE / "data" / f"{collection_id}-pc.json"
+    out_path = tmp_path / collection_id
+    parse_stac_ndjson_to_delta_lake(path, out_path)
+
+    # Read back into table and convert to json
+    dt = DeltaTable(out_path)
+    table = dt.to_pyarrow_table()
+    items_result = list(stac_table_to_items(table))
+
+    # Compare with original json
+    with open(HERE / "data" / f"{collection_id}-pc.json") as f:
+        items = json.load(f)
+
+    for result, expected in zip(items_result, items):
+        assert_json_value_equal(result, expected, precision=0)
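Beyond the pyarrow round trip exercised by the test above, a written table can be loaded into GeoPandas by decoding the WKB geometry column; a sketch with a hypothetical table path, assuming the usual WGS84 STAC coordinates.

import geopandas as gpd
import shapely
from deltalake import DeltaTable

# Load the Delta Lake table as a pandas DataFrame.
df = DeltaTable("out/stac-table").to_pyarrow_table().to_pandas()
# Decode WKB geometries and build a GeoDataFrame.
df["geometry"] = shapely.from_wkb(df["geometry"])
gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:4326")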