4 changes: 2 additions & 2 deletions .github/workflows/continuous-integration.yml
@@ -23,12 +23,12 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Sync
run: uv sync --all-extras
run: scripts/install
- name: Pre-commit
run: uv run pre-commit run --all-files
- name: Lint
run: scripts/lint
- name: Test
run: uv run pytest tests -v
run: scripts/test
- name: Check docs
run: uv run mkdocs build --strict
10 changes: 7 additions & 3 deletions README.md
@@ -30,13 +30,17 @@ Install via `pip` or `conda`:

## Development

Get [uv](https://docs.astral.sh/uv/getting-started/installation/), then:
1. Create a Python virtual environment
1. Get [uv](https://docs.astral.sh/uv/getting-started/installation/):
1. Get [cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html):

Then run:

```shell
git clone [email protected]:stac-utils/stac-geoparquet.git
cd stac-geoparquet
uv sync
scripts/install
uv run pre-commit install
uv run pytest
scripts/test
scripts/lint
```
5 changes: 5 additions & 0 deletions scripts/install
@@ -0,0 +1,5 @@
#!/usr/bin/env sh

set -e

uv sync --all-extras
5 changes: 5 additions & 0 deletions scripts/test
@@ -0,0 +1,5 @@
#!/usr/bin/env sh

set -e

uv run pytest tests -v
169 changes: 115 additions & 54 deletions stac_geoparquet/pgstac_reader.py
@@ -7,22 +7,25 @@
import itertools
import logging
import textwrap
from typing import Any
from typing import Any, Literal

import dateutil.tz
import fsspec
import orjson
import pandas as pd
import pyarrow.fs
import pypgstac.db
import pypgstac.hydration
import pystac
import shapely.wkb
import tqdm.auto
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_fixed
Collaborator
Is this a new dependency? I don't see it in

dependencies = [
"ciso8601",
"deltalake",
"geopandas",
"packaging",
"pandas",
# Needed for RecordBatch.append_column
# Below 19 b/c https://github.com/apache/arrow/issues/45283
"pyarrow>=16,<19",
"pyproj",
"pystac",
"shapely",
"orjson",
'typing_extensions; python_version < "3.11"',
]
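
If `tenacity` is indeed a new runtime dependency, one way to resolve this comment — sketched here as an assumption, not a change included in this PR — would be to declare it alongside the existing entries in `pyproject.toml`:

```toml
# Hypothetical sketch: declare tenacity as a runtime dependency in pyproject.toml.
# The unpinned "tenacity" entry below is an assumption, not part of this PR.
dependencies = [
    # ... existing entries unchanged ...
    "tenacity",
]
```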


from stac_geoparquet import to_geodataframe
from stac_geoparquet.arrow import parse_stac_ndjson_to_parquet

logger = logging.getLogger(__name__)

EXPORT_FORMAT = Literal["geoparquet", "ndjson"]


def _pairwise(
iterable: collections.abc.Iterable,
@@ -148,32 +151,21 @@ def export_partition(
storage_options: dict[str, Any] | None = None,
rewrite: bool = False,
skip_empty_partitions: bool = False,
format: EXPORT_FORMAT = "geoparquet",
) -> str | None:
storage_options = storage_options or {}
az_fs = fsspec.filesystem(output_protocol, **storage_options)
if az_fs.exists(output_path) and not rewrite:
logger.debug("Path %s already exists.", output_path)
return output_path

db = pypgstac.db.PgstacDB(conninfo)
with db:
assert db.connection is not None
db.connection.execute("set statement_timeout = 300000;")
# logger.debug("Reading base item")
# TODO: proper escaping
base_item = db.query_one(
f"select * from collection_base_item('{self.collection_id}');"
)
records = list(db.query(query))
fs = fsspec.filesystem(output_protocol, **storage_options)

base_item, records = _enumerate_db_items(self.collection_id, conninfo, query)

if skip_empty_partitions and len(records) == 0:
logger.debug("No records found for query %s.", query)
return None

items = self.make_pgstac_items(records, base_item) # type: ignore[arg-type]
df = to_geodataframe(items)
filesystem = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(az_fs))
df.to_parquet(output_path, index=False, filesystem=filesystem)

logger.debug("Exporting %d items as %s to %s", len(items), format, output_path)
_write_ndjson(output_path, fs, items)
return output_path

def export_partition_for_endpoints(
@@ -187,6 +179,7 @@ def export_partition_for_endpoints(
total: int | None = None,
rewrite: bool = False,
skip_empty_partitions: bool = False,
format: EXPORT_FORMAT = "geoparquet",
) -> str | None:
"""
Export results for a pair of endpoints.
@@ -205,7 +198,9 @@ def export_partition_for_endpoints(
+ f"and datetime >= '{a.isoformat()}' and datetime < '{b.isoformat()}'"
)

partition_path = _build_output_path(output_path, part_number, total, a, b)
partition_path = _build_output_path(
output_path, part_number, total, a, b, format=format
)
return self.export_partition(
conninfo,
query,
@@ -214,6 +209,7 @@ def export_partition_for_endpoints(
storage_options=storage_options,
rewrite=rewrite,
skip_empty_partitions=skip_empty_partitions,
format=format,
)

def export_collection(
@@ -224,6 +220,7 @@ def export_collection(
storage_options: dict[str, Any],
rewrite: bool = False,
skip_empty_partitions: bool = False,
format: EXPORT_FORMAT = "geoparquet",
) -> list[str | None]:
base_query = textwrap.dedent(
f"""\
@@ -232,45 +229,77 @@ def export_collection(
where collection = '{self.collection_id}'
"""
)
if output_protocol:
output_path = f"{output_protocol}://{output_path}"

intermediate_path = f"/tmp/{self.collection_id}.ndjson"
results: list[str | None] = []
if not self.partition_frequency:
logger.info("Exporting single-partition collection %s", self.collection_id)
logger.info(
"Exporting single-partition collection %s to ndjson", self.collection_id
)
logger.debug("query=%s", base_query)
results = [
self.export_partition(
conninfo,
base_query,
output_protocol,
output_path,
storage_options=storage_options,
rewrite=rewrite,
)
]

# First write NDJSON to disk
self.export_partition(
conninfo,
base_query,
"file",
intermediate_path,
storage_options={"auto_mkdir": True},
rewrite=rewrite,
format="ndjson",
)
if output_protocol:
output_path = f"{output_protocol}://{output_path}.parquet"
logger.debug("Writing geoparquet to %s", output_path)
results.append(intermediate_path)
parse_stac_ndjson_to_parquet(
results,
output_path,
filesystem=fsspec.filesystem(output_protocol, **storage_options),
)
else:
endpoints = self.generate_endpoints()
total = len(endpoints)
if output_protocol:
output_path = f"{output_protocol}://{output_path}.parquet"
logger.info(
"Exporting %d partitions for collection %s", total, self.collection_id
)

results = []
for i, endpoint in tqdm.auto.tqdm(enumerate(endpoints), total=total):
results.append(
self.export_partition_for_endpoints(
endpoints=endpoint,
conninfo=conninfo,
output_protocol=output_protocol,
output_path=output_path,
storage_options=storage_options,
rewrite=rewrite,
skip_empty_partitions=skip_empty_partitions,
part_number=i,
total=total,
)
partition = self.export_partition_for_endpoints(
endpoints=endpoint,
conninfo=conninfo,
output_protocol="file",
output_path=intermediate_path,
storage_options={"auto_mkdir": True},
rewrite=rewrite,
skip_empty_partitions=skip_empty_partitions,
part_number=i,
total=total,
format="ndjson",
)
if partition:
results.append(partition)
partition_path = _build_output_path(
output_path,
i,
total,
endpoint[0],
endpoint[1],
format="geoparquet",
)
logger.debug("Writing geoparquet to %s", partition_path)
parse_stac_ndjson_to_parquet(
partition,
partition_path,
filesystem=fsspec.filesystem(
output_protocol, **storage_options
),
)

# delete every file in the results list
for result in results:
logger.debug("Cleaning up %s", result)
fsspec.filesystem("file").rm(result, recursive=True)

return results

@@ -340,20 +369,52 @@ def _build_output_path(
total: int | None,
start_datetime: datetime.datetime,
end_datetime: datetime.datetime,
format: EXPORT_FORMAT = "geoparquet",
) -> str:
a, b = start_datetime, end_datetime
base_output_path = base_output_path.rstrip("/")
file_extensions = {
"geoparquet": "parquet",
"ndjson": "ndjson",
}

if part_number is not None and total is not None:
output_path = (
f"{base_output_path}/part-{part_number:0{len(str(total * 10))}}_"
f"{a.isoformat()}_{b.isoformat()}.parquet"
f"{a.isoformat()}_{b.isoformat()}.{file_extensions[format]}"
)
else:
token = hashlib.md5(
"".join([a.isoformat(), b.isoformat()]).encode()
).hexdigest()
output_path = (
f"{base_output_path}/part-{token}_{a.isoformat()}_{b.isoformat()}.parquet"
)
output_path = f"{base_output_path}/part-{token}_{a.isoformat()}_{b.isoformat()}.{file_extensions[format]}"
return output_path


@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def _enumerate_db_items(
collection_id: str, conninfo: str, query: str
) -> tuple[Any, list[Any]]:
db = pypgstac.db.PgstacDB(conninfo)
with db:
assert db.connection is not None
db.connection.execute("set statement_timeout = 300000;")
# logger.debug("Reading base item")
# TODO: proper escaping
base_item = db.query_one(
f"select * from collection_base_item('{collection_id}');"
)
records = list(db.query(query))
return base_item, records


def _write_ndjson(
output_path: str, fs: fsspec.AbstractFileSystem, items: list[dict]
) -> None:
with fs.open(output_path, "wb") as f:
for item in items:
f.write(orjson.dumps(item))
f.write(b"\n")
53 changes: 43 additions & 10 deletions tests/test_pgstac_reader.py
@@ -4,6 +4,7 @@
import sys

import dateutil
import fsspec
import pandas as pd
import pystac
import pytest
@@ -14,6 +15,19 @@

HERE = pathlib.Path(__file__).parent

@pytest.fixture
def sentinel2_collection_config() -> stac_geoparquet.pgstac_reader.CollectionConfig:
return stac_geoparquet.pgstac_reader.CollectionConfig(
collection_id="sentinel-2-l2a",
partition_frequency=None,
stac_api="https://planetarycomputer.microsoft.com/api/stac/v1",
should_inject_dynamic_properties=True,
render_config="assets=visual&asset_bidx=visual%7C1%2C2%2C3&nodata=0&format=png",
)

@pytest.fixture
def sentinel2_record():
return json.loads(HERE.joinpath("record_sentinel2_l2a.json").read_text())

@pytest.mark.vcr
@pytest.mark.skipif(
@@ -124,20 +138,15 @@ def test_naip_item():
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="vcr tests require python3.10 or higher"
)
def test_sentinel2_l2a():
record = json.loads(HERE.joinpath("record_sentinel2_l2a.json").read_text())
def test_sentinel2_l2a(
sentinel2_collection_config: stac_geoparquet.pgstac_reader.CollectionConfig,
sentinel2_record) -> None:
record = sentinel2_record
base_item = json.loads(HERE.joinpath("base_sentinel2_l2a.json").read_text())
record[3] = dateutil.parser.parse(record[3])
record[4] = dateutil.parser.parse(record[4])

config = stac_geoparquet.pgstac_reader.CollectionConfig(
collection_id="sentinel-2-l2a",
partition_frequency=None,
stac_api="https://planetarycomputer.microsoft.com/api/stac/v1",
should_inject_dynamic_properties=True,
render_config="assets=visual&asset_bidx=visual%7C1%2C2%2C3&nodata=0&format=png",
)
result = pystac.read_dict(config.make_pgstac_items([record], base_item)[0])
result = pystac.read_dict(
sentinel2_collection_config.make_pgstac_items([record], base_item)[0]
)
expected = pystac.read_file(
"https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2A_MSIL2A_20150704T101006_R022_T35XQA_20210411T133707" # noqa: E501
)
@@ -199,3 +208,27 @@ def test_build_output_path(part_number, total, start_datetime, end_datetime, exp
base_output_path, part_number, total, start_datetime, end_datetime
)
assert result == expected

def test_write_ndjson(
tmp_path,
sentinel2_collection_config: stac_geoparquet.pgstac_reader.CollectionConfig,
sentinel2_record) -> None:
record = sentinel2_record
base_item = json.loads(HERE.joinpath("base_sentinel2_l2a.json").read_text())

items = sentinel2_collection_config.make_pgstac_items(
[record, record], base_item)
fs = fsspec.filesystem("file")
stac_geoparquet.pgstac_reader._write_ndjson(
tmp_path / "test.ndjson",
fs,
items
)
# check that the file has 2 lines
with fs.open(tmp_path / "test.ndjson") as f:
lines = f.readlines()
assert len(lines) == 2
# check that the first line is a valid json
json.loads(lines[0])
# check that the second line is a valid json
json.loads(lines[1])