diff --git a/CHANGELOG.md b/CHANGELOG.md index 897badb..df0bf7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased] +### Added + +- Construct `stac_api::Search` (moved from `stac_api` crate) ([#81](https://github.com/stac-utils/stacrs/pull/81)) + ### Fixed - Swallow broken pipe errors ([#73](https://github.com/stac-utils/stacrs/pull/73)) diff --git a/Cargo.lock b/Cargo.lock index 3362b95..35e7864 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2695,7 +2695,7 @@ dependencies = [ [[package]] name = "pgstac" version = "0.3.0" -source = "git+https://github.com/stac-utils/stac-rs?branch=main#83fd636ef3b21d65761a21228a1ff2ca094b5090" +source = "git+https://github.com/stac-utils/stac-rs?branch=main#e59405cf4d566218cc9fa5fef52b21d96244ef5f" dependencies = [ "serde", "serde_json", @@ -3398,9 +3398,9 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dade4812df5c384711475be5fcd8c162555352945401aed22a35bffeab61f657" +checksum = "f7178faa4b75a30e269c71e61c353ce2748cf3d76f0c44c393f4e60abf49b825" dependencies = [ "bitflags 2.9.0", "errno", @@ -3712,7 +3712,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stac" version = "0.12.0" -source = "git+https://github.com/stac-utils/stac-rs?branch=main#83fd636ef3b21d65761a21228a1ff2ca094b5090" +source = "git+https://github.com/stac-utils/stac-rs?branch=main#e59405cf4d566218cc9fa5fef52b21d96244ef5f" dependencies = [ "arrow-array", "arrow-cast", @@ -3743,7 +3743,7 @@ dependencies = [ [[package]] name = "stac-api" version = "0.7.1" -source = "git+https://github.com/stac-utils/stac-rs?branch=main#83fd636ef3b21d65761a21228a1ff2ca094b5090" +source = "git+https://github.com/stac-utils/stac-rs?branch=main#e59405cf4d566218cc9fa5fef52b21d96244ef5f" dependencies = [ "async-stream", 
"chrono", @@ -3752,8 +3752,6 @@ dependencies = [ "geo", "geojson", "http", - "pyo3", - "pythonize", "reqwest", "serde", "serde_json", @@ -3769,7 +3767,7 @@ dependencies = [ [[package]] name = "stac-cli" version = "0.5.3" -source = "git+https://github.com/stac-utils/stac-rs?branch=main#83fd636ef3b21d65761a21228a1ff2ca094b5090" +source = "git+https://github.com/stac-utils/stac-rs?branch=main#e59405cf4d566218cc9fa5fef52b21d96244ef5f" dependencies = [ "anyhow", "axum", @@ -3787,7 +3785,7 @@ dependencies = [ [[package]] name = "stac-derive" version = "0.2.0" -source = "git+https://github.com/stac-utils/stac-rs?branch=main#83fd636ef3b21d65761a21228a1ff2ca094b5090" +source = "git+https://github.com/stac-utils/stac-rs?branch=main#e59405cf4d566218cc9fa5fef52b21d96244ef5f" dependencies = [ "quote", "syn 2.0.100", @@ -3796,7 +3794,7 @@ dependencies = [ [[package]] name = "stac-duckdb" version = "0.1.1" -source = "git+https://github.com/stac-utils/stac-rs?branch=main#83fd636ef3b21d65761a21228a1ff2ca094b5090" +source = "git+https://github.com/stac-utils/stac-rs?branch=main#e59405cf4d566218cc9fa5fef52b21d96244ef5f" dependencies = [ "arrow", "chrono", @@ -3814,7 +3812,7 @@ dependencies = [ [[package]] name = "stac-server" version = "0.3.4" -source = "git+https://github.com/stac-utils/stac-rs?branch=main#83fd636ef3b21d65761a21228a1ff2ca094b5090" +source = "git+https://github.com/stac-utils/stac-rs?branch=main#e59405cf4d566218cc9fa5fef52b21d96244ef5f" dependencies = [ "axum", "bb8", diff --git a/Cargo.toml b/Cargo.toml index 3674b0b..9e4562d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,6 @@ stac = { features = [ ], git = "https://github.com/stac-utils/stac-rs", branch = "main" } stac-api = { features = [ "client", - "python", ], git = "https://github.com/stac-utils/stac-rs", branch = "main" } stac-cli = { git = "https://github.com/stac-utils/stac-rs", features = [ "pgstac", diff --git a/src/duckdb.rs b/src/duckdb.rs index 71382a6..e42ec86 100644 --- a/src/duckdb.rs +++ 
b/src/duckdb.rs @@ -1,4 +1,7 @@ -use crate::Result; +use crate::{ + search::{PySortby, StringOrDict, StringOrList}, + Result, +}; use pyo3::{ exceptions::PyException, prelude::*, @@ -6,7 +9,6 @@ use pyo3::{ IntoPyObjectExt, }; use pyo3_arrow::PyTable; -use stac_api::python::{StringOrDict, StringOrList}; use stac_duckdb::{Client, Config}; use std::sync::Mutex; @@ -16,11 +18,24 @@ pub struct DuckdbClient(Mutex); #[pymethods] impl DuckdbClient { #[new] - #[pyo3(signature = (use_s3_credential_chain=true, use_hive_partitioning=false))] - fn new(use_s3_credential_chain: bool, use_hive_partitioning: bool) -> Result { + #[pyo3(signature = (*, use_s3_credential_chain=true, use_azure_credential_chain=true, use_httpfs=true, use_hive_partitioning=false, install_extensions=true, custom_extension_repository=None, extension_directory=None))] + fn new( + use_s3_credential_chain: bool, + use_azure_credential_chain: bool, + use_httpfs: bool, + use_hive_partitioning: bool, + install_extensions: bool, + custom_extension_repository: Option, + extension_directory: Option, + ) -> Result { let config = Config { use_s3_credential_chain, + use_azure_credential_chain, + use_httpfs, use_hive_partitioning, + install_extensions, + custom_extension_repository, + extension_directory, convert_wkb: true, }; let client = Client::with_config(config)?; @@ -40,12 +55,12 @@ impl DuckdbClient { datetime: Option, include: Option, exclude: Option, - sortby: Option, + sortby: Option>, filter: Option, query: Option>, kwargs: Option>, ) -> Result> { - let search = stac_api::python::search( + let search = crate::search::build( intersects, ids, collections, @@ -84,12 +99,12 @@ impl DuckdbClient { datetime: Option, include: Option, exclude: Option, - sortby: Option, + sortby: Option>, filter: Option, query: Option>, kwargs: Option>, ) -> Result { - let search = stac_api::python::search( + let search = crate::search::build( intersects, ids, collections, diff --git a/src/search.rs b/src/search.rs index 
339e2f9..af8a3d9 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,10 +1,10 @@ use crate::{Error, Json, Result}; -use pyo3::{prelude::*, types::PyDict}; +use geojson::Geometry; +use pyo3::prelude::*; +use pyo3::{exceptions::PyValueError, types::PyDict, Bound, FromPyObject, PyErr, PyResult}; +use stac::Bbox; use stac::Format; -use stac_api::{ - python::{StringOrDict, StringOrList}, - Search, -}; +use stac_api::{Fields, Filter, Items, Search, Sortby}; #[pyfunction] #[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, use_duckdb=None, **kwargs))] @@ -21,13 +21,13 @@ pub fn search<'py>( datetime: Option, include: Option, exclude: Option, - sortby: Option, + sortby: Option>, filter: Option, query: Option>, use_duckdb: Option, kwargs: Option>, ) -> PyResult> { - let search = stac_api::python::search( + let search = build( intersects, ids, collections, @@ -72,7 +72,7 @@ pub fn search_to<'py>( datetime: Option, include: Option, exclude: Option, - sortby: Option, + sortby: Option>, filter: Option, query: Option>, format: Option, @@ -80,7 +80,7 @@ pub fn search_to<'py>( use_duckdb: Option, kwargs: Option>, ) -> PyResult> { - let search = stac_api::python::search( + let search = build( intersects, ids, collections, @@ -150,3 +150,144 @@ async fn search_api( let value = stac_api::client::search(&href, search, max_items).await?; Ok(value) } + +/// Creates a [Search] from Python arguments. 
+#[allow(clippy::too_many_arguments)] +pub fn build<'py>( + intersects: Option>, + ids: Option, + collections: Option, + limit: Option, + bbox: Option>, + datetime: Option, + include: Option, + exclude: Option, + sortby: Option>, + filter: Option>, + query: Option>, + kwargs: Option>, +) -> PyResult { + let mut fields = Fields::default(); + if let Some(include) = include { + fields.include = include.into(); + } + if let Some(exclude) = exclude { + fields.exclude = exclude.into(); + } + let fields = if fields.include.is_empty() && fields.exclude.is_empty() { + None + } else { + Some(fields) + }; + let query = query + .map(|query| pythonize::depythonize(&query)) + .transpose()?; + let bbox = bbox.map(Bbox::try_from).transpose().map_err(Error::from)?; + let sortby: Vec = sortby + .map(|sortby| match sortby { + PySortby::ListOfDicts(list) => list + .into_iter() + .map(|d| pythonize::depythonize(&d).map_err(Error::from)) + .collect::>>(), + PySortby::ListOfStrings(list) => list + .into_iter() + .map(|s| Ok(s.parse().unwrap())) // infallible + .collect::>>(), + PySortby::String(s) => Ok(vec![s.parse().unwrap()]), + }) + .transpose()? 
+ .unwrap_or_default(); + let filter = filter + .map(|filter| match filter { + StringOrDict::Dict(cql_json) => pythonize::depythonize(&cql_json).map(Filter::Cql2Json), + StringOrDict::String(cql2_text) => Ok(Filter::Cql2Text(cql2_text)), + }) + .transpose()?; + let filter = filter + .map(|filter| filter.into_cql2_json()) + .transpose() + .map_err(Error::from)?; + let mut items = Items { + limit, + bbox, + datetime, + query, + fields, + sortby, + filter, + ..Default::default() + }; + if let Some(kwargs) = kwargs { + items.additional_fields = pythonize::depythonize(&kwargs)?; + } + + let intersects = intersects + .map(|intersects| match intersects { + StringOrDict::Dict(json) => pythonize::depythonize(&json) + .map_err(PyErr::from) + .and_then(|json| { + Geometry::from_json_object(json) + .map_err(|err| PyValueError::new_err(err.to_string())) + }), + StringOrDict::String(s) => s + .parse::() + .map_err(|err| PyValueError::new_err(err.to_string())), + }) + .transpose()?; + let ids = ids.map(|ids| ids.into()).unwrap_or_default(); + let collections = collections.map(|ids| ids.into()).unwrap_or_default(); + Ok(Search { + items, + intersects, + ids, + collections, + }) +} + +/// A string or dictionary. +/// +/// Used for the CQL2 filter argument and for intersects. +#[derive(Debug, FromPyObject)] +pub enum StringOrDict<'py> { + /// Text + String(String), + + /// Json + Dict(Bound<'py, PyDict>), +} + +/// A string or a list. +/// +/// Used for collections, ids, etc. +#[derive(Debug, FromPyObject)] +pub enum StringOrList { + /// A string. + String(String), + + /// A list. + List(Vec), +} + +/// A sortby structure. +/// +/// This can be a string, a list of strings, or a list of dictionaries. +#[derive(Debug, FromPyObject)] +pub enum PySortby<'py> { + /// A string. + String(String), + + /// A list. + ListOfStrings(Vec), + + /// A list. 
+ ListOfDicts(Vec>), +} + +impl From for Vec { + fn from(value: StringOrList) -> Vec { + match value { + StringOrList::List(list) => list, + StringOrList::String(s) => vec![s], + } + } +} diff --git a/stacrs.pyi b/stacrs.pyi index c04f81b..b3425c4 100644 --- a/stacrs.pyi +++ b/stacrs.pyi @@ -6,15 +6,30 @@ class DuckdbClient: """A client for querying stac-geoparquet with DuckDB.""" def __init__( - self, use_s3_credential_chain: bool = True, use_hive_partitioning: bool = False + self, + *, + use_s3_credential_chain: bool = True, + use_azure_credential_chain: bool = True, + use_httpfs: bool = True, + use_hive_partitioning: bool = False, + install_extensions: bool = True, + custom_extension_repository: str | None = None, + extension_directory: str | None = None, ) -> None: """Creates a new duckdb client. Args: use_s3_credential_chain: If true, configures DuckDB to correctly handle s3:// urls. + use_azure_credential_chain: If true, configures DuckDB to correctly + handle azure urls. + use_httpfs: If true, configures DuckDB to correctly + handle https urls. use_hive_partitioning: If true, enables queries on hive partitioned geoparquet files. + install_extensions: If true, installs extensions before loading them. + custom_extension_repository: A custom extension repository to use. + extension_directory: A non-standard extension directory to use. 
""" def search( @@ -30,7 +45,7 @@ class DuckdbClient: datetime: Optional[str] = None, include: Optional[str | list[str]] = None, exclude: Optional[str | list[str]] = None, - sortby: Optional[str | list[str]] = None, + sortby: Optional[str | list[str | dict[str, str]]] = None, filter: Optional[str | dict[str, Any]] = None, query: Optional[dict[str, Any]] = None, **kwargs: str, @@ -79,7 +94,7 @@ class DuckdbClient: datetime: Optional[str] = None, include: Optional[str | list[str]] = None, exclude: Optional[str | list[str]] = None, - sortby: Optional[str | list[str]] = None, + sortby: Optional[str | list[str | dict[str, str]]] = None, filter: Optional[str | dict[str, Any]] = None, query: Optional[dict[str, Any]] = None, **kwargs: str, @@ -260,7 +275,7 @@ async def search( datetime: Optional[str] = None, include: Optional[str | list[str]] = None, exclude: Optional[str | list[str]] = None, - sortby: Optional[str | list[str]] = None, + sortby: Optional[str | list[str | dict[str, str]]] = None, filter: Optional[str | dict[str, Any]] = None, query: Optional[dict[str, Any]] = None, use_duckdb: Optional[bool] = None, @@ -329,7 +344,7 @@ async def search_to( datetime: Optional[str] = None, include: Optional[str | list[str]] = None, exclude: Optional[str | list[str]] = None, - sortby: Optional[str | list[str]] = None, + sortby: Optional[str | list[str | dict[str, str]]] = None, filter: Optional[str | dict[str, Any]] = None, query: Optional[dict[str, Any]] = None, format: Optional[str] = None, diff --git a/tests/test_search.py b/tests/test_search.py index 472aae7..f31c8f9 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -11,7 +11,6 @@ async def test_search() -> None: "https://landsatlook.usgs.gov/stac-server", collections="landsat-c2l2-sr", intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, - sortby="-properties.datetime", max_items=1, ) assert len(items) == 1 @@ -23,7 +22,6 @@ async def test_search_to(tmp_path: Path) -> None: 
"https://landsatlook.usgs.gov/stac-server", collections="landsat-c2l2-sr", intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, - sortby="-properties.datetime", max_items=1, ) with open(tmp_path / "out.json") as f: @@ -37,7 +35,6 @@ async def test_search_to_geoparquet(tmp_path: Path) -> None: "https://landsatlook.usgs.gov/stac-server", collections="landsat-c2l2-sr", intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, - sortby="-properties.datetime", max_items=1, ) assert count == 1 @@ -49,3 +46,16 @@ async def test_search_to_geoparquet(tmp_path: Path) -> None: async def test_search_geoparquet(data: Path) -> None: items = await stacrs.search(str(data / "extended-item.parquet")) assert len(items) == 1 + + +async def test_sortby_list_of_dict() -> None: + items = await stacrs.search( + "https://landsatlook.usgs.gov/stac-server", + collections="landsat-c2l2-sr", + intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, + sortby=[ + {"field": "properties.datetime", "direction": "asc"}, + ], + max_items=1, + ) + assert len(items) == 1