diff --git a/.gitignore b/.gitignore index edd30e7..943c576 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,5 @@ cython_debug/ #.idea/ .vscode/ docs/generated + +tests/duckdb-extensions diff --git a/Cargo.lock b/Cargo.lock index f32ff26..4fa5906 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1018,6 +1018,7 @@ dependencies = [ "memchr", "num-integer", "rust_decimal", + "serde_json", "smallvec", "strum", ] @@ -2167,9 +2168,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.11" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "c9627da5196e5d8ed0b0495e61e518847578da83483c37288316d9b2e03a7f72" [[package]] name = "libredox" @@ -2670,7 +2671,7 @@ dependencies = [ [[package]] name = "pgstac" version = "0.3.0" -source = "git+https://github.com/stac-utils/rustac?branch=main#6c284a1999a3fdbadd5f311a9d2d450249b09758" +source = "git+https://github.com/stac-utils/rustac?branch=main#8f24498fd3d227067dab8738e1a792e57b7d7024" dependencies = [ "serde", "serde_json", @@ -3000,9 +3001,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.10" +version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc" +checksum = "bcbafbbdbb0f638fe3f35f3c56739f77a8a1d070cb25603226c83339b391472b" dependencies = [ "bytes", "getrandom 0.3.2", @@ -3375,7 +3376,7 @@ dependencies = [ [[package]] name = "rustac" version = "0.5.3" -source = "git+https://github.com/stac-utils/rustac?branch=main#6c284a1999a3fdbadd5f311a9d2d450249b09758" +source = "git+https://github.com/stac-utils/rustac?branch=main#8f24498fd3d227067dab8738e1a792e57b7d7024" dependencies = [ "anyhow", "axum", @@ -3396,6 +3397,7 @@ version = "0.6.0" dependencies = [ "cargo-lock", "clap", + "duckdb", "geoarrow-array", "geojson", "pyo3", @@ -3777,7 +3779,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stac" version = "0.12.0" -source = "git+https://github.com/stac-utils/rustac?branch=main#6c284a1999a3fdbadd5f311a9d2d450249b09758" +source = "git+https://github.com/stac-utils/rustac?branch=main#8f24498fd3d227067dab8738e1a792e57b7d7024" dependencies = [ "arrow-array", "arrow-cast", @@ -3812,7 +3814,7 @@ dependencies = [ [[package]] name = "stac-api" version = "0.7.1" -source = "git+https://github.com/stac-utils/rustac?branch=main#6c284a1999a3fdbadd5f311a9d2d450249b09758" +source = "git+https://github.com/stac-utils/rustac?branch=main#8f24498fd3d227067dab8738e1a792e57b7d7024" dependencies = [ "async-stream", "chrono", @@ -3837,7 +3839,7 @@ dependencies = [ [[package]] name = "stac-derive" version = "0.2.0" -source = "git+https://github.com/stac-utils/rustac?branch=main#6c284a1999a3fdbadd5f311a9d2d450249b09758" +source = "git+https://github.com/stac-utils/rustac?branch=main#8f24498fd3d227067dab8738e1a792e57b7d7024" dependencies = [ "quote", "syn 2.0.100", @@ -3846,7 +3848,7 @@ dependencies = [ [[package]] name = "stac-duckdb" version = "0.1.1" -source = "git+https://github.com/stac-utils/rustac?branch=main#6c284a1999a3fdbadd5f311a9d2d450249b09758" +source = "git+https://github.com/stac-utils/rustac?branch=main#8f24498fd3d227067dab8738e1a792e57b7d7024" dependencies = [ "arrow-array", "chrono", @@ -3865,7 +3867,7 @@ dependencies = [ [[package]] name = "stac-server" version = "0.3.4" -source = "git+https://github.com/stac-utils/rustac?branch=main#6c284a1999a3fdbadd5f311a9d2d450249b09758" +source = "git+https://github.com/stac-utils/rustac?branch=main#8f24498fd3d227067dab8738e1a792e57b7d7024" dependencies = [ "axum", "bb8", diff --git a/Cargo.toml b/Cargo.toml index 76c0008..cb719d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ duckdb-bundled = ["stac-duckdb/bundled"] [dependencies] clap = "4.5.31" +duckdb = { version = "1.2.2", features = ["serde_json"] } geoarrow-array = { git = "https://github.com/geoarrow/geoarrow-rs/", rev = "17bf33e4cf78b060afa08ca9560dc4efd73c2c76" } geojson = "0.24.1" pyo3 = { version = "0.24.1", features = ["extension-module"] } diff --git a/python/rustac/rustac.pyi b/python/rustac/rustac.pyi index b1d5df4..5871ca5 100644 --- a/python/rustac/rustac.pyi +++ b/python/rustac/rustac.pyi @@ -1,5 +1,6 @@ """The power of Rust for the Python STAC ecosystem.""" +from pathlib import Path from typing import Any, AsyncIterator, Literal, Optional, Tuple import arro3.core @@ -15,28 +16,28 @@ class DuckdbClient: def __init__( self, *, - use_s3_credential_chain: bool = True, - use_azure_credential_chain: bool = True, - use_httpfs: bool = True, + extension_directory: Path | None = None, + extensions: list[str] | None = None, + install_spatial: bool = True, use_hive_partitioning: bool = False, - install_extensions: bool = True, - custom_extension_repository: str | None = None, - extension_directory: str | None = None, ) -> None: """Creates a new duckdb client. Args: - use_s3_credential_chain: If true, configures DuckDB to correctly - handle s3:// urls. - use_azure_credential_chain: If true, configures DuckDB to correctly - handle azure urls. - use_httpfs: If true, configures DuckDB to correctly handle https - urls. - use_hive_partitioning: If true, enables queries on hive partitioned - geoparquet files. - install_extensions: If true, installs extensions before loading them. - custom_extension_repository: A custom extension repository to use. extension_directory: A non-standard extension directory to use. + extensions: A list of extensions to LOAD on client initialization. + install_spatial: Whether to install the spatial extension on client initialization. + use_hive_partitioning: Whether to use hive partitioning for geoparquet queries. + """ + + def execute(self, sql: str, params: list[str] | None = None) -> int: + """Execute an SQL command. + + This can be useful for configuring AWS credentials, for example. + + Args: + sql: The SQL to execute + params: The parameters to pass in to the execution """ def search( diff --git a/scripts/test b/scripts/test index d5d4d2a..3676ba3 100755 --- a/scripts/test +++ b/scripts/test @@ -2,6 +2,5 @@ set -e -uv run maturin dev --uv -E arrow uv run pytest "$@" uv run rustac translate spec-examples/v1.1.0/simple-item.json /dev/null diff --git a/src/duckdb.rs b/src/duckdb.rs index 558af51..39f0685 100644 --- a/src/duckdb.rs +++ b/src/duckdb.rs @@ -2,6 +2,7 @@ use crate::{ search::{PySortby, StringOrDict, StringOrList}, Result, }; +use duckdb::Connection; use pyo3::{ exceptions::PyException, prelude::*, @@ -9,8 +10,8 @@ use pyo3::{ IntoPyObjectExt, }; use pyo3_arrow::PyTable; -use stac_duckdb::{Client, Config}; -use std::sync::Mutex; +use stac_duckdb::Client; +use std::{path::PathBuf, sync::Mutex}; #[pyclass(frozen)] pub struct DuckdbClient(Mutex); @@ -18,30 +19,42 @@ pub struct DuckdbClient(Mutex); #[pymethods] impl DuckdbClient { #[new] - #[pyo3(signature = (*, use_s3_credential_chain=false, use_azure_credential_chain=false, use_httpfs=false, use_hive_partitioning=false, install_extensions=true, custom_extension_repository=None, extension_directory=None))] + #[pyo3(signature = (*, extension_directory=None, extensions=Vec::new(), install_spatial=true, use_hive_partitioning=false))] fn new( - use_s3_credential_chain: bool, - use_azure_credential_chain: bool, - use_httpfs: bool, + extension_directory: Option, + extensions: Vec, + install_spatial: bool, use_hive_partitioning: bool, - install_extensions: bool, - custom_extension_repository: Option, - extension_directory: Option, ) -> Result { - let config = Config { - use_s3_credential_chain, - use_azure_credential_chain, - use_httpfs, - use_hive_partitioning, - install_extensions, - custom_extension_repository, - extension_directory, - convert_wkb: true, - }; - let client = Client::with_config(config)?; + let connection = Connection::open_in_memory()?; + if let Some(extension_directory) = extension_directory { + connection.execute( + "SET extension_directory = ?", + [extension_directory.to_string_lossy()], + )?; + } + if install_spatial { + connection.execute("INSTALL spatial", [])?; + } + for extension in extensions { + connection.execute(&format!("LOAD '{}'", extension), [])?; + } + connection.execute("LOAD spatial", [])?; + let mut client = Client::from(connection); + client.use_hive_partitioning = use_hive_partitioning; Ok(DuckdbClient(Mutex::new(client))) } + #[pyo3(signature = (sql, params = Vec::new()))] + fn execute<'py>(&self, sql: String, params: Vec) -> Result { + let client = self + .0 + .lock() + .map_err(|err| PyException::new_err(err.to_string()))?; + let count = client.execute(&sql, duckdb::params_from_iter(params))?; + Ok(count) + } + #[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, **kwargs))] fn search<'py>( &self, @@ -123,10 +136,11 @@ impl DuckdbClient { .0 .lock() .map_err(|err| PyException::new_err(err.to_string()))?; - let convert_wkb = client.config.convert_wkb; - client.config.convert_wkb = false; + // FIXME this is awkward + let convert_wkb = client.convert_wkb; + client.convert_wkb = false; let result = client.search_to_arrow(&href, search); - client.config.convert_wkb = convert_wkb; + client.convert_wkb = convert_wkb; result? }; if record_batches.is_empty() { diff --git a/src/error.rs b/src/error.rs index 27273e7..a7d8ec0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,9 @@ create_exception!(rustac, RustacError, PyException); #[derive(Debug, Error)] pub enum Error { + #[error(transparent)] + Duckdb(#[from] duckdb::Error), + #[error(transparent)] Geojson(#[from] geojson::Error), diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py index aec217e..f163ff4 100644 --- a/tests/test_duckdb.py +++ b/tests/test_duckdb.py @@ -1,7 +1,9 @@ +from pathlib import Path + import pytest import rustac from geopandas import GeoDataFrame -from rustac import DuckdbClient +from rustac import DuckdbClient, RustacError @pytest.fixture @@ -9,6 +11,11 @@ def client() -> DuckdbClient: return DuckdbClient() +@pytest.fixture +def extension_directory() -> Path: + return Path(__file__).parent / "duckdb-extensions" + + def test_search(client: DuckdbClient) -> None: items = client.search("data/extended-item.parquet") assert len(items) == 1 @@ -27,11 +34,6 @@ def test_get_collections(client: DuckdbClient) -> None: assert len(collections) == 1 -@pytest.mark.skip("slow") -def test_init_with_config() -> None: - DuckdbClient(use_s3_credential_chain=True, use_hive_partitioning=True) - - def test_search_to_arrow(client: DuckdbClient) -> None: pytest.importorskip("arro3.core") table = client.search_to_arrow("data/100-sentinel-2-items.parquet") @@ -40,3 +42,39 @@ def test_search_to_arrow(client: DuckdbClient) -> None: data_frame_table = data_frame.to_arrow() item_collection = rustac.from_arrow(data_frame_table) assert len(item_collection["features"]) == 100 + + +def test_custom_extension_directory(extension_directory: Path) -> None: + client = DuckdbClient(extension_directory=extension_directory) + # Search to ensure we trigger everything + client.search("data/100-sentinel-2-items.parquet") + + +def test_no_install(tmp_path: Path) -> None: + with pytest.raises(RustacError): + DuckdbClient(extension_directory=tmp_path, install_spatial=False) + + +def test_extensions(extension_directory: Path, tmp_path: Path) -> None: + # Ensure we've fetched the extension + DuckdbClient(extension_directory=extension_directory) + + extension = next(extension_directory.glob("**/spatial.duckdb_extension")) + client = DuckdbClient( + extensions=[str(extension)], extension_directory=tmp_path, install_spatial=False + ) + client.search("data/100-sentinel-2-items.parquet") + + +def test_execute(client: DuckdbClient, extension_directory: Path) -> None: + # Just a smoke test + client.execute("SET extension_directory = ?", [str(extension_directory)]) + + +def test_load_spatial() -> None: + DuckdbClient(extensions=["spatial"]) + + +@pytest.mark.skip("slow") +def test_aws_credential_chain(client: DuckdbClient) -> None: + client.execute("CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)")