diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml
index b646840..e664f36 100644
--- a/.github/workflows/continuous-integration.yml
+++ b/.github/workflows/continuous-integration.yml
@@ -23,12 +23,12 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Sync
-        run: uv sync --all-extras
+        run: scripts/install
       - name: Pre-commit
         run: uv run pre-commit run --all-files
       - name: Lint
         run: scripts/lint
       - name: Test
-        run: uv run pytest tests -v
+        run: scripts/test
       - name: Check docs
         run: uv run mkdocs build --strict
diff --git a/README.md b/README.md
index b5e1a31..207c4f9 100644
--- a/README.md
+++ b/README.md
@@ -30,13 +30,17 @@ Install via `pip` or `conda`:
 
 ## Development
 
-Get [uv](https://docs.astral.sh/uv/getting-started/installation/), then:
+1. Create a Python virtual environment
+1. Get [uv](https://docs.astral.sh/uv/getting-started/installation/)
+1. Get [cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html)
+
+Then run:
 
 ```shell
 git clone git@github.com:stac-utils/stac-geoparquet.git
 cd stac-geoparquet
-uv sync
+scripts/install
 uv run pre-commit install
-uv run pytest
+scripts/test
 scripts/lint
 ```
diff --git a/scripts/install b/scripts/install
new file mode 100755
index 0000000..78b19bd
--- /dev/null
+++ b/scripts/install
@@ -0,0 +1,5 @@
+#!/usr/bin/env sh
+
+set -e
+
+uv sync --all-extras
diff --git a/scripts/test b/scripts/test
new file mode 100755
index 0000000..fc2620e
--- /dev/null
+++ b/scripts/test
@@ -0,0 +1,5 @@
+#!/usr/bin/env sh
+
+set -e
+
+uv run pytest tests -v
diff --git a/stac_geoparquet/pgstac_reader.py b/stac_geoparquet/pgstac_reader.py
index 665ac16..21e764e 100644
--- a/stac_geoparquet/pgstac_reader.py
+++ b/stac_geoparquet/pgstac_reader.py
@@ -7,22 +7,25 @@
 import itertools
 import logging
 import textwrap
-from typing import Any
+from typing import Any, Literal
 
 import dateutil.tz
 import fsspec
+import orjson
 import pandas as pd
-import pyarrow.fs
 import pypgstac.db
 import pypgstac.hydration
 import pystac
 import shapely.wkb
 import tqdm.auto
+from tenacity import before_sleep_log, retry, stop_after_attempt, wait_fixed
 
-from stac_geoparquet import to_geodataframe
+from stac_geoparquet.arrow import parse_stac_ndjson_to_parquet
 
 logger = logging.getLogger(__name__)
 
+EXPORT_FORMAT = Literal["geoparquet", "ndjson"]
+
 
 def _pairwise(
     iterable: collections.abc.Iterable,
@@ -148,32 +151,21 @@ def export_partition(
         storage_options: dict[str, Any] | None = None,
         rewrite: bool = False,
         skip_empty_partitions: bool = False,
+        format: EXPORT_FORMAT = "geoparquet",
     ) -> str | None:
         storage_options = storage_options or {}
-        az_fs = fsspec.filesystem(output_protocol, **storage_options)
-        if az_fs.exists(output_path) and not rewrite:
-            logger.debug("Path %s already exists.", output_path)
-            return output_path
-
-        db = pypgstac.db.PgstacDB(conninfo)
-        with db:
-            assert db.connection is not None
-            db.connection.execute("set statement_timeout = 300000;")
-            # logger.debug("Reading base item")
-            # TODO: proper escaping
-            base_item = db.query_one(
-                f"select * from collection_base_item('{self.collection_id}');"
-            )
-            records = list(db.query(query))
+        fs = fsspec.filesystem(output_protocol, **storage_options)
+
+        base_item, records = _enumerate_db_items(self.collection_id, conninfo, query)
 
         if skip_empty_partitions and len(records) == 0:
             logger.debug("No records found for query %s.", query)
             return None
 
         items = self.make_pgstac_items(records, base_item)  # type: ignore[arg-type]
-        df = to_geodataframe(items)
-        filesystem = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(az_fs))
-        df.to_parquet(output_path, index=False, filesystem=filesystem)
+
+        logger.debug("Exporting %d items as %s to %s", len(items), format, output_path)
+        _write_ndjson(output_path, fs, items)
         return output_path
 
     def export_partition_for_endpoints(
@@ -187,6 +179,7 @@
         total: int | None = None,
         rewrite: bool = False,
         skip_empty_partitions: bool = False,
+        format: EXPORT_FORMAT = "geoparquet",
     ) -> str | None:
         """
         Export results for a pair of endpoints.
@@ -205,7 +198,9 @@
             + f"and datetime >= '{a.isoformat()}' and datetime < '{b.isoformat()}'"
         )
 
-        partition_path = _build_output_path(output_path, part_number, total, a, b)
+        partition_path = _build_output_path(
+            output_path, part_number, total, a, b, format=format
+        )
         return self.export_partition(
             conninfo,
             query,
@@ -214,6 +209,7 @@
             storage_options=storage_options,
             rewrite=rewrite,
             skip_empty_partitions=skip_empty_partitions,
+            format=format,
         )
 
     def export_collection(
@@ -224,6 +220,7 @@
         storage_options: dict[str, Any],
         rewrite: bool = False,
         skip_empty_partitions: bool = False,
+        format: EXPORT_FORMAT = "geoparquet",
     ) -> list[str | None]:
         base_query = textwrap.dedent(
             f"""\
@@ -232,45 +229,77 @@
             where collection = '{self.collection_id}'
             """
         )
-        if output_protocol:
-            output_path = f"{output_protocol}://{output_path}"
+        intermediate_path = f"/tmp/{self.collection_id}.ndjson"
+        results: list[str | None] = []
 
         if not self.partition_frequency:
-            logger.info("Exporting single-partition collection %s", self.collection_id)
+            logger.info(
+                "Exporting single-partition collection %s to ndjson", self.collection_id
+            )
             logger.debug("query=%s", base_query)
 
-            results = [
-                self.export_partition(
-                    conninfo,
-                    base_query,
-                    output_protocol,
-                    output_path,
-                    storage_options=storage_options,
-                    rewrite=rewrite,
-                )
-            ]
-
+            # First write NDJSON to disk
+            self.export_partition(
+                conninfo,
+                base_query,
+                "file",
+                intermediate_path,
+                storage_options={"auto_mkdir": True},
+                rewrite=rewrite,
+                format="ndjson",
+            )
+            if output_protocol:
+                output_path = f"{output_protocol}://{output_path}.parquet"
+            logger.debug("Writing geoparquet to %s", output_path)
+            results.append(intermediate_path)
+            parse_stac_ndjson_to_parquet(
+                results,
+                output_path,
+                filesystem=fsspec.filesystem(output_protocol, **storage_options),
+            )
         else:
             endpoints = self.generate_endpoints()
             total = len(endpoints)
+            if output_protocol:
+                output_path = f"{output_protocol}://{output_path}.parquet"
             logger.info(
                 "Exporting %d partitions for collection %s", total, self.collection_id
            )
-
-            results = []
             for i, endpoint in tqdm.auto.tqdm(enumerate(endpoints), total=total):
-                results.append(
-                    self.export_partition_for_endpoints(
-                        endpoints=endpoint,
-                        conninfo=conninfo,
-                        output_protocol=output_protocol,
-                        output_path=output_path,
-                        storage_options=storage_options,
-                        rewrite=rewrite,
-                        skip_empty_partitions=skip_empty_partitions,
-                        part_number=i,
-                        total=total,
-                    )
+                partition = self.export_partition_for_endpoints(
+                    endpoints=endpoint,
+                    conninfo=conninfo,
+                    output_protocol="file",
+                    output_path=intermediate_path,
+                    storage_options={"auto_mkdir": True},
+                    rewrite=rewrite,
+                    skip_empty_partitions=skip_empty_partitions,
+                    part_number=i,
+                    total=total,
+                    format="ndjson",
                 )
+                if partition:
+                    results.append(partition)
+                    partition_path = _build_output_path(
+                        output_path,
+                        i,
+                        total,
+                        endpoint[0],
+                        endpoint[1],
+                        format="geoparquet",
+                    )
+                    logger.debug("Writing geoparquet to %s", partition_path)
+                    parse_stac_ndjson_to_parquet(
+                        partition,
+                        partition_path,
+                        filesystem=fsspec.filesystem(
+                            output_protocol, **storage_options
+                        ),
+                    )
+
+        # Clean up the intermediate NDJSON files collected in results
+        for result in results:
+            logger.debug("Cleaning up %s", result)
+            fsspec.filesystem("file").rm(result, recursive=True)
 
         return results
 
@@ -340,20 +369,52 @@ def _build_output_path(
     total: int | None,
     start_datetime: datetime.datetime,
     end_datetime: datetime.datetime,
+    format: EXPORT_FORMAT = "geoparquet",
 ) -> str:
     a, b = start_datetime, end_datetime
     base_output_path = base_output_path.rstrip("/")
+    file_extensions = {
+        "geoparquet": "parquet",
+        "ndjson": "ndjson",
+    }
 
     if part_number is not None and total is not None:
         output_path = (
             f"{base_output_path}/part-{part_number:0{len(str(total * 10))}}_"
-            f"{a.isoformat()}_{b.isoformat()}.parquet"
+            f"{a.isoformat()}_{b.isoformat()}.{file_extensions[format]}"
         )
     else:
         token = hashlib.md5(
             "".join([a.isoformat(), b.isoformat()]).encode()
         ).hexdigest()
-        output_path = (
-            f"{base_output_path}/part-{token}_{a.isoformat()}_{b.isoformat()}.parquet"
-        )
+        output_path = f"{base_output_path}/part-{token}_{a.isoformat()}_{b.isoformat()}.{file_extensions[format]}"
     return output_path
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_fixed(2),
+    before_sleep=before_sleep_log(logger, logging.DEBUG))
+def _enumerate_db_items(
+    collection_id: str, conninfo: str, query: str
+) -> tuple[Any, list[Any]]:
+    db = pypgstac.db.PgstacDB(conninfo)
+    with db:
+        assert db.connection is not None
+        db.connection.execute("set statement_timeout = 300000;")
+        # logger.debug("Reading base item")
+        # TODO: proper escaping
+        base_item = db.query_one(
+            f"select * from collection_base_item('{collection_id}');"
+        )
+        records = list(db.query(query))
+    return base_item, records
+
+
+def _write_ndjson(
+    output_path: str, fs: fsspec.AbstractFileSystem, items: list[dict]
+) -> None:
+    with fs.open(output_path, "wb") as f:
+        for item in items:
+            f.write(orjson.dumps(item))
+            f.write(b"\n")
diff --git a/tests/test_pgstac_reader.py b/tests/test_pgstac_reader.py
index 9fa6fdb..3cc2bd8 100644
--- a/tests/test_pgstac_reader.py
+++ b/tests/test_pgstac_reader.py
@@ -4,6 +4,7 @@
 import sys
 
 import dateutil
+import fsspec
 import pandas as pd
 import pystac
 import pytest
@@ -14,6 +15,19 @@
 
 HERE = pathlib.Path(__file__).parent
 
+@pytest.fixture
+def sentinel2_collection_config() -> stac_geoparquet.pgstac_reader.CollectionConfig:
+    return stac_geoparquet.pgstac_reader.CollectionConfig(
+        collection_id="sentinel-2-l2a",
+        partition_frequency=None,
+        stac_api="https://planetarycomputer.microsoft.com/api/stac/v1",
+        should_inject_dynamic_properties=True,
+        render_config="assets=visual&asset_bidx=visual%7C1%2C2%2C3&nodata=0&format=png",
+    )
+
+@pytest.fixture
+def sentinel2_record():
+    return json.loads(HERE.joinpath("record_sentinel2_l2a.json").read_text())
 
 @pytest.mark.vcr
 @pytest.mark.skipif(
@@ -124,20 +138,15 @@
 @pytest.mark.skipif(
     sys.version_info < (3, 10), reason="vcr tests require python3.10 or higher"
 )
-def test_sentinel2_l2a():
-    record = json.loads(HERE.joinpath("record_sentinel2_l2a.json").read_text())
+def test_sentinel2_l2a(
+    sentinel2_collection_config: stac_geoparquet.pgstac_reader.CollectionConfig,
+    sentinel2_record) -> None:
+    record = sentinel2_record
     base_item = json.loads(HERE.joinpath("base_sentinel2_l2a.json").read_text())
     record[3] = dateutil.parser.parse(record[3])
     record[4] = dateutil.parser.parse(record[4])
 
-    config = stac_geoparquet.pgstac_reader.CollectionConfig(
-        collection_id="sentinel-2-l2a",
-        partition_frequency=None,
-        stac_api="https://planetarycomputer.microsoft.com/api/stac/v1",
-        should_inject_dynamic_properties=True,
-        render_config="assets=visual&asset_bidx=visual%7C1%2C2%2C3&nodata=0&format=png",
-    )
-    result = pystac.read_dict(config.make_pgstac_items([record], base_item)[0])
+    result = pystac.read_dict(sentinel2_collection_config.make_pgstac_items([record], base_item)[0])
     expected = pystac.read_file(
         "https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2A_MSIL2A_20150704T101006_R022_T35XQA_20210411T133707"  # noqa: E501
     )
@@ -199,3 +208,27 @@ def test_build_output_path(part_number, total, start_datetime, end_datetime, exp
         base_output_path, part_number, total, start_datetime, end_datetime
     )
     assert result == expected
+
+def test_write_ndjson(
+    tmp_path,
+    sentinel2_collection_config: stac_geoparquet.pgstac_reader.CollectionConfig,
+    sentinel2_record) -> None:
+    record = sentinel2_record
+    base_item = json.loads(HERE.joinpath("base_sentinel2_l2a.json").read_text())
+
+    items = sentinel2_collection_config.make_pgstac_items(
+        [record, record], base_item)
+    fs = fsspec.filesystem("file")
+    stac_geoparquet.pgstac_reader._write_ndjson(
+        tmp_path / "test.ndjson",
+        fs,
+        items
+    )
+    # check that the file has 2 lines
+    with fs.open(tmp_path / "test.ndjson") as f:
+        lines = f.readlines()
+    assert len(lines) == 2
+    # check that the first line is valid JSON
+    json.loads(lines[0])
+    # check that the second line is valid JSON
+    json.loads(lines[1])
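Reviewer note: the patch replaces the old `to_geodataframe` → `pyarrow` write with a two-step flow — `export_partition` now streams items to newline-delimited JSON via the new `_write_ndjson` helper, and `export_collection` converts that intermediate to GeoParquet with `parse_stac_ndjson_to_parquet`. Below is a minimal sketch of that flow outside of pgstac, assuming the private helper keeps the signature shown in this diff; the item dicts and `/tmp` paths are invented for illustration, and such pared-down items may be too minimal for a real conversion.

```python
# Sketch only: invented items and paths; the call shapes mirror the patched pgstac_reader.
import fsspec

from stac_geoparquet.arrow import parse_stac_ndjson_to_parquet
from stac_geoparquet.pgstac_reader import _write_ndjson

# Stand-ins for what make_pgstac_items(records, base_item) would return.
items = [
    {
        "type": "Feature",
        "stac_version": "1.0.0",
        "id": f"demo-item-{i}",
        "collection": "demo-collection",
        "geometry": {"type": "Point", "coordinates": [0.0, 0.0]},
        "bbox": [0.0, 0.0, 0.0, 0.0],
        "properties": {"datetime": "2021-01-01T00:00:00Z"},
        "links": [],
        "assets": {},
    }
    for i in range(2)
]

fs = fsspec.filesystem("file")
intermediate_path = "/tmp/demo-collection.ndjson"  # same /tmp/{collection_id}.ndjson pattern

# Step 1: one JSON document per line, as export_partition(format="ndjson") now writes.
_write_ndjson(intermediate_path, fs, items)

# Step 2: convert the NDJSON intermediate to GeoParquet, as export_collection now does.
parse_stac_ndjson_to_parquet(intermediate_path, "/tmp/demo-collection.parquet", filesystem=fs)
```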
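A second sketch, for the naming change: the new `format` argument to `_build_output_path` only switches the file extension through the `file_extensions` lookup, so the intermediate NDJSON partition and the final GeoParquet partition share the same `part-{number}_{start}_{end}` stem. The base path and datetimes below are made up.

```python
# Sketch: invented base path and datetimes; shows the extension switch added in this patch.
import datetime

from stac_geoparquet.pgstac_reader import _build_output_path

start = datetime.datetime(2021, 1, 1)
end = datetime.datetime(2021, 2, 1)

# Intermediate partition written by export_partition_for_endpoints(format="ndjson"):
print(_build_output_path("/tmp/demo", 3, 12, start, end, format="ndjson"))
# -> /tmp/demo/part-003_2021-01-01T00:00:00_2021-02-01T00:00:00.ndjson

# Final partition name used when converting to GeoParquet:
print(_build_output_path("/tmp/demo", 3, 12, start, end, format="geoparquet"))
# -> /tmp/demo/part-003_2021-01-01T00:00:00_2021-02-01T00:00:00.parquet
```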