diff --git a/Cargo.lock b/Cargo.lock index 875a54f..b9ff64b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1094,9 +1094,9 @@ dependencies = [ [[package]] name = "geo-types" -version = "0.7.14" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6f47c611187777bbca61ea7aba780213f5f3441fd36294ab333e96cfa791b65" +checksum = "3bd1157f0f936bf0cd68dec91e8f7c311afe60295574d62b70d4861a1bfdf2d9" dependencies = [ "approx", "num-traits", @@ -1868,9 +1868,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "lz4_flex" @@ -2964,9 +2964,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "indexmap", "itoa", @@ -3104,8 +3104,8 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stac" -version = "0.11.1" -source = "git+https://github.com/stac-utils/stac-rs#0476743b841c6e413b77a5281a1515172fbcd804" +version = "0.12.0" +source = "git+https://github.com/stac-utils/stac-rs#eb404cf095247e716609a1b44efd47b08fd9f6c2" dependencies = [ "arrow-array", "arrow-cast", @@ -3132,8 +3132,8 @@ dependencies = [ [[package]] name = "stac-api" -version = "0.7.0" -source = "git+https://github.com/stac-utils/stac-rs#0476743b841c6e413b77a5281a1515172fbcd804" +version = "0.7.1" +source = "git+https://github.com/stac-utils/stac-rs#eb404cf095247e716609a1b44efd47b08fd9f6c2" dependencies = [ "async-stream", "chrono", @@ -3158,7 +3158,7 @@ dependencies = [ [[package]] name = "stac-derive" version = "0.2.0" -source = "git+https://github.com/stac-utils/stac-rs#0476743b841c6e413b77a5281a1515172fbcd804" +source = "git+https://github.com/stac-utils/stac-rs#eb404cf095247e716609a1b44efd47b08fd9f6c2" dependencies = [ "quote", "syn 2.0.95", @@ -3166,8 +3166,8 @@ dependencies = [ [[package]] name = "stac-duckdb" -version = "0.1.0" -source = "git+https://github.com/stac-utils/stac-rs#0476743b841c6e413b77a5281a1515172fbcd804" +version = "0.1.1" +source = "git+https://github.com/stac-utils/stac-rs#eb404cf095247e716609a1b44efd47b08fd9f6c2" dependencies = [ "arrow", "chrono", @@ -3198,7 +3198,6 @@ dependencies = [ "stac-api", "stac-duckdb", "thiserror 2.0.11", - "tokio", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ab2e901..d7f7e97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,15 +19,14 @@ pyo3-async-runtimes = { version = "0.23.0", features = [ pyo3-log = "0.12.1" pythonize = "0.23.0" serde = "1.0.217" -serde_json = "1.0.135" -stac = { version = "0.11.1", features = [ +serde_json = "1.0.138" +stac = { version = "0.12.0", features = [ "geoparquet-compression", "object-store-all", ], git = "https://github.com/stac-utils/stac-rs" } -stac-api = { version = "0.7.0", features = [ +stac-api = { version = "0.7.1", features = [ "client", "python", ], git = "https://github.com/stac-utils/stac-rs" } stac-duckdb = { version = "0.1.0", git = "https://github.com/stac-utils/stac-rs" } thiserror = "2.0.11" -tokio = { version = "1.43.0", features = ["rt"] } diff --git a/src/search.rs b/src/search.rs index cf50e4a..339e2f9 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,15 +1,10 @@ -use crate::Error; -use pyo3::{ - prelude::*, - types::{PyDict, PyList}, -}; +use crate::{Error, Json, Result}; +use pyo3::{prelude::*, types::PyDict}; use stac::Format; use stac_api::{ python::{StringOrDict, StringOrList}, - BlockingClient, Item, ItemCollection, + Search, }; -use stac_duckdb::Client; -use tokio::runtime::Builder; #[pyfunction] #[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, use_duckdb=None, **kwargs))] @@ -31,13 +26,11 @@ pub fn search<'py>( query: Option>, use_duckdb: Option, kwargs: Option>, -) -> PyResult> { - let items = search_items( - href, +) -> PyResult> { + let search = stac_api::python::search( intersects, ids, collections, - max_items, limit, bbox, datetime, @@ -46,18 +39,28 @@ pub fn search<'py>( sortby, filter, query, - use_duckdb, kwargs, )?; - pythonize::pythonize(py, &items) - .map_err(PyErr::from) - .and_then(|v| v.extract()) + if use_duckdb + .unwrap_or_else(|| matches!(Format::infer_from_href(&href), Some(Format::Geoparquet(_)))) + { + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value = search_duckdb(href, search, max_items)?; + Ok(Json(value.items)) + }) + } else { + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value = search_api(href, search, max_items).await?; + Ok(Json(value.items)) + }) + } } #[pyfunction] #[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, options=None, use_duckdb=None, **kwargs))] #[allow(clippy::too_many_arguments)] pub fn search_to<'py>( + py: Python<'py>, outfile: String, href: String, intersects: Option, @@ -76,13 +79,11 @@ pub fn search_to<'py>( options: Option>, use_duckdb: Option, kwargs: Option>, -) -> PyResult { - let items = search_items( - href, +) -> PyResult> { + let search = stac_api::python::search( intersects, ids, collections, - max_items, limit, bbox, datetime, @@ -91,7 +92,6 @@ pub fn search_to<'py>( sortby, filter, query, - use_duckdb, kwargs, )?; let format = format @@ -100,77 +100,53 @@ pub fn search_to<'py>( .map_err(Error::from)? .or_else(|| Format::infer_from_href(&outfile)) .unwrap_or_default(); - let item_collection = ItemCollection::from(items); - let count = item_collection.items.len(); - Builder::new_current_thread() - .build()? - .block_on(format.put_opts( - outfile, - serde_json::to_value(item_collection).map_err(Error::from)?, - options.unwrap_or_default(), - )) - .map_err(Error::from)?; - Ok(count) -} - -#[allow(clippy::too_many_arguments)] -fn search_items<'py>( - href: String, - intersects: Option, - ids: Option, - collections: Option, - max_items: Option, - limit: Option, - bbox: Option>, - datetime: Option, - include: Option, - exclude: Option, - sortby: Option, - filter: Option, - query: Option>, - use_duckdb: Option, - kwargs: Option>, -) -> PyResult> { - let mut search = stac_api::python::search( - intersects, - ids, - collections, - limit, - bbox, - datetime, - include, - exclude, - sortby, - filter, - query, - kwargs, - )?; if use_duckdb .unwrap_or_else(|| matches!(Format::infer_from_href(&href), Some(Format::Geoparquet(_)))) { - if let Some(max_items) = max_items { - search.items.limit = Some(max_items.try_into()?); - } - let client = Client::new().map_err(Error::from)?; - client - .search_to_json(&href, search) - .map(|item_collection| item_collection.items) - .map_err(Error::from) - .map_err(PyErr::from) + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value = search_duckdb(href, search, max_items)?; + let count = value.items.len(); + let _ = format + .put_opts( + outfile, + serde_json::to_value(value).map_err(Error::from)?, + options.unwrap_or_default(), + ) + .await + .map_err(Error::from)?; + Ok(count) + }) } else { - let client = BlockingClient::new(&href).map_err(Error::from)?; - let items = client.search(search).map_err(Error::from)?; - if let Some(max_items) = max_items { - items - .take(max_items) - .collect::>() - .map_err(Error::from) - .map_err(PyErr::from) - } else { - items - .collect::>() - .map_err(Error::from) - .map_err(PyErr::from) - } + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value = search_api(href, search, max_items).await?; + let count = value.items.len(); + let _ = format + .put_opts( + outfile, + serde_json::to_value(value).map_err(Error::from)?, + options.unwrap_or_default(), + ) + .await + .map_err(Error::from)?; + Ok(count) + }) } } + +fn search_duckdb( + href: String, + search: Search, + max_items: Option, +) -> Result { + let value = stac_duckdb::search(&href, search, max_items)?; + Ok(value) +} + +async fn search_api( + href: String, + search: Search, + max_items: Option, +) -> Result { + let value = stac_api::client::search(&href, search, max_items).await?; + Ok(value) +} diff --git a/stacrs.pyi b/stacrs.pyi index 1563061..f29d414 100644 --- a/stacrs.pyi +++ b/stacrs.pyi @@ -142,7 +142,7 @@ async def read( >>> item = await stacrs.read("item.json") """ -def search( +async def search( href: str, *, intersects: Optional[str | dict[str, Any]] = None, @@ -201,7 +201,7 @@ def search( list[dict[str, Any]]: A list of the returned STAC items. Examples: - >>> items = stacrs.search( + >>> items = await stacrs.search( ... "https://landsatlook.usgs.gov/stac-server", ... collections=["landsat-c2l2-sr"], ... intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, @@ -210,7 +210,7 @@ def search( ... ) """ -def search_to( +async def search_to( outfile: str, href: str, *, @@ -272,10 +272,10 @@ def search_to( to None. Returns: - list[dict[str, Any]]: A list of the returned STAC items. + int: The number of items written Examples: - >>> items = stacrs.search_to("out.parquet", + >>> count = await stacrs.search_to("out.parquet", ... "https://landsatlook.usgs.gov/stac-server", ... collections=["landsat-c2l2-sr"], ... intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, diff --git a/tests/test_search.py b/tests/test_search.py index ba2f07a..472aae7 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -6,8 +6,8 @@ import stacrs -def test_search() -> None: - items = stacrs.search( +async def test_search() -> None: + items = await stacrs.search( "https://landsatlook.usgs.gov/stac-server", collections="landsat-c2l2-sr", intersects={"type": "Point", "coordinates": [-105.119, 40.173]}, @@ -17,8 +17,8 @@ def test_search() -> None: assert len(items) == 1 -def test_search_to(tmp_path: Path) -> None: - stacrs.search_to( +async def test_search_to(tmp_path: Path) -> None: + await stacrs.search_to( str(tmp_path / "out.json"), "https://landsatlook.usgs.gov/stac-server", collections="landsat-c2l2-sr", @@ -31,8 +31,8 @@ def test_search_to(tmp_path: Path) -> None: assert len(data["features"]) == 1 -def test_search_to_geoparquet(tmp_path: Path) -> None: - count = stacrs.search_to( +async def test_search_to_geoparquet(tmp_path: Path) -> None: + count = await stacrs.search_to( str(tmp_path / "out.parquet"), "https://landsatlook.usgs.gov/stac-server", collections="landsat-c2l2-sr", @@ -46,6 +46,6 @@ def test_search_to_geoparquet(tmp_path: Path) -> None: assert len(items) == 1 -def test_search_geoparquet(data: Path) -> None: - items = stacrs.search(str(data / "extended-item.parquet")) +async def test_search_geoparquet(data: Path) -> None: + items = await stacrs.search(str(data / "extended-item.parquet")) assert len(items) == 1 diff --git a/uv.lock b/uv.lock index 18c7dfb..4f744d3 100644 --- a/uv.lock +++ b/uv.lock @@ -2040,7 +2040,6 @@ wheels = [ [[package]] name = "stacrs" -version = "0.4.0" source = { editable = "." } [package.dev-dependencies]