diff --git a/Cargo.lock b/Cargo.lock
index 2c34376..6877702 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2656,6 +2656,7 @@ dependencies = [
  "num-bigint",
  "paste",
  "seq-macro",
+ "simdutf8",
  "snap",
  "thrift",
  "twox-hash 2.1.0",
@@ -3507,6 +3508,7 @@ dependencies = [
  "duckdb",
  "geoarrow-array 0.1.0-dev",
  "geojson",
+ "parquet",
  "pyo3",
  "pyo3-arrow",
  "pyo3-async-runtimes",
diff --git a/Cargo.toml b/Cargo.toml
index 05f7a72..82917e6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,6 +43,7 @@ tokio = { version = "1.44.0", features = ["rt-multi-thread"] }
 pyo3-log = "0.12.1"
 tracing = "0.1.41"
 pyo3-object_store = "0.2.0"
+parquet = "55.1.0"
 
 [build-dependencies]
 cargo-lock = "10"
diff --git a/python/rustac/rustac.pyi b/python/rustac/rustac.pyi
index d5bec1a..65f5fb3 100644
--- a/python/rustac/rustac.pyi
+++ b/python/rustac/rustac.pyi
@@ -350,6 +350,7 @@ async def search_to(
     filter: str | dict[str, Any] | None = None,
     query: dict[str, Any] | None = None,
     format: str | None = None,
+    parquet_compression: str | None = None,
     store: AnyObjectStore | None = None,
     use_duckdb: bool | None = None,
 ) -> int:
@@ -389,6 +390,10 @@
             It is recommended to use filter instead, if possible.
         format: The output format. If none, will be inferred from the outfile
             extension, and if that fails will fall back to compact JSON.
+        parquet_compression: If writing stac-geoparquet, sets the compression
+            algorithm.
+            https://docs.rs/parquet/latest/parquet/basic/enum.Compression.html
+            is a list of what's available.
         store: An optional [ObjectStore][]
         use_duckdb: Query with DuckDB. If None and the href has a 'parquet'
             or 'geoparquet' extension, will be set to True. Defaults
@@ -428,6 +433,7 @@ async def write(
     value: dict[str, Any] | Sequence[dict[str, Any]],
     *,
     format: str | None = None,
+    parquet_compression: str | None = None,
     store: AnyObjectStore | None = None,
 ) -> dict[str, str] | None:
     """
@@ -439,6 +445,10 @@
             can be a STAC dictionary or a list of items.
         format: The output format to write. If not provided, will be inferred
             from the href's extension.
+        parquet_compression: If writing stac-geoparquet, sets the compression
+            algorithm.
+            https://docs.rs/parquet/latest/parquet/basic/enum.Compression.html
+            is a list of what's available.
         store: The object store to use for writing.
 
     Returns:
diff --git a/src/error.rs b/src/error.rs
index aa44e2e..58595be 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -41,6 +41,9 @@
     #[error(transparent)]
     StacDuckdb(#[from] stac_duckdb::Error),
 
+    #[error(transparent)]
+    Parquet(#[from] parquet::errors::ParquetError),
+
     #[error(transparent)]
     TokioTaskJon(#[from] tokio::task::JoinError),
 }
diff --git a/src/search.rs b/src/search.rs
index d7dad03..a8f0723 100644
--- a/src/search.rs
+++ b/src/search.rs
@@ -58,7 +58,7 @@ pub fn search<'py>(
 }
 
 #[pyfunction]
-#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, store=None, use_duckdb=None, **kwargs))]
+#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, parquet_compression=None, store=None, use_duckdb=None, **kwargs))]
 #[allow(clippy::too_many_arguments)]
 pub fn search_to<'py>(
     py: Python<'py>,
@@ -77,6 +77,7 @@
     filter: Option<String>,
     query: Option<Bound<'py, PyDict>>,
     format: Option<String>,
+    parquet_compression: Option<String>,
     store: Option<AnyObjectStore>,
     use_duckdb: Option<bool>,
     kwargs: Option<Bound<'py, PyDict>>,
@@ -95,12 +96,18 @@
         query,
         kwargs,
     )?;
-    let format = format
+    let mut format = format
         .map(|s| s.parse())
         .transpose()
         .map_err(Error::from)?
        .or_else(|| Format::infer_from_href(&outfile))
        .unwrap_or_default();
+    if matches!(format, Format::Geoparquet(_)) {
+        if let Some(parquet_compression) = parquet_compression {
+            tracing::debug!("setting parquet compression: {parquet_compression}");
+            format = Format::Geoparquet(Some(parquet_compression.parse().map_err(Error::from)?));
+        }
+    }
     if use_duckdb
         .unwrap_or_else(|| matches!(Format::infer_from_href(&href), Some(Format::Geoparquet(_))))
     {
diff --git a/src/write.rs b/src/write.rs
index 2923257..9db293c 100644
--- a/src/write.rs
+++ b/src/write.rs
@@ -6,12 +6,13 @@ use stac::{Item, ItemCollection};
 use stac_io::{Format, StacStore};
 
 #[pyfunction]
-#[pyo3(signature = (href, value, *, format=None, store=None))]
+#[pyo3(signature = (href, value, *, format=None, parquet_compression=None, store=None))]
 pub fn write<'py>(
     py: Python<'py>,
     href: String,
     value: Bound<'_, PyAny>,
     format: Option<String>,
+    parquet_compression: Option<String>,
     store: Option<AnyObjectStore>,
 ) -> PyResult<Bound<'py, PyAny>> {
     let value: Value = pythonize::depythonize(&value)?;
@@ -25,10 +26,17 @@
         serde_json::from_value(value).map_err(Error::from)?
     };
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
-        let format = format
+        let mut format = format
             .and_then(|f| f.parse::<Format>().ok())
             .or_else(|| Format::infer_from_href(&href))
             .unwrap_or_default();
+        if matches!(format, Format::Geoparquet(_)) {
+            if let Some(parquet_compression) = parquet_compression {
+                tracing::debug!("setting parquet compression: {parquet_compression}");
+                format =
+                    Format::Geoparquet(Some(parquet_compression.parse().map_err(Error::from)?));
+            }
+        }
         let (stac_store, path) = if let Some(store) = store {
             (StacStore::from(store.into_dyn()), None)
         } else {
diff --git a/tests/test_write.py b/tests/test_write.py
index eedb0af..70c2c25 100644
--- a/tests/test_write.py
+++ b/tests/test_write.py
@@ -5,6 +5,7 @@ import pyarrow.parquet
 import rustac
 import stac_geoparquet
+from pyarrow.parquet import ParquetFile
 from rustac.store import LocalStore
 
 
@@ -35,3 +36,14 @@ async def test_write_includes_type(tmp_path: Path, item: dict[str, Any]) -> None:
     await rustac.write(str(tmp_path / "out.parquet"), [item])
     data_frame = pandas.read_parquet(str(tmp_path / "out.parquet"))
     assert "type" in data_frame.columns
+
+
+async def test_write_parquet_compression(tmp_path: Path, item: dict[str, Any]) -> None:
+    await rustac.write(
+        str(tmp_path / "out.parquet"), [item], parquet_compression="zstd(1)"
+    )
+    parquet_file = ParquetFile(tmp_path / "out.parquet")
+    metadata = parquet_file.metadata
+    for row_group in range(metadata.num_row_groups):
+        for column in range(metadata.num_columns):
+            assert metadata.row_group(row_group).column(column).compression == "ZSTD"
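
For reference, a minimal usage sketch of the new `parquet_compression` parameter from Python; this is not part of the patch itself. The `items` list below is a placeholder for real STAC item dictionaries, and the compression string follows the parquet crate's Compression syntax linked from the docstrings (for example "snappy", "gzip(4)", or "zstd(1)", as used in the test above).

    import asyncio

    import rustac

    async def main() -> None:
        # Placeholder: assumed to be a list of STAC item dictionaries.
        items: list[dict] = []
        # Write stac-geoparquet with zstd compression at level 1.
        await rustac.write("items.parquet", items, parquet_compression="zstd(1)")

    asyncio.run(main())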