Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,5 @@ cython_debug/
#.idea/
.vscode/
docs/generated

tests/duckdb-extensions
24 changes: 13 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ duckdb-bundled = ["stac-duckdb/bundled"]

[dependencies]
clap = "4.5.31"
duckdb = { version = "1.2.2", features = ["serde_json"] }
geoarrow-array = { git = "https://github.com/geoarrow/geoarrow-rs/", rev = "17bf33e4cf78b060afa08ca9560dc4efd73c2c76" }
geojson = "0.24.1"
pyo3 = { version = "0.24.1", features = ["extension-module"] }
Expand Down
33 changes: 17 additions & 16 deletions python/rustac/rustac.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The power of Rust for the Python STAC ecosystem."""

from pathlib import Path
from typing import Any, AsyncIterator, Literal, Optional, Tuple

import arro3.core
Expand All @@ -15,28 +16,28 @@ class DuckdbClient:
# NOTE(review): the parameter list and the Args section below appear to interleave
# removed and added lines from a rendered diff (`extension_directory` is declared
# twice, once as `Path | None` and once as `str | None`) — confirm against the real stub.
# NOTE(review): `extensions` is typed `list[str] | None` here, while the Rust
# constructor declares `extensions: Vec<String>` with a `Vec::new()` default;
# presumably an explicit `extensions=None` is rejected at runtime — verify.
def __init__(
self,
*,
use_s3_credential_chain: bool = True,
use_azure_credential_chain: bool = True,
use_httpfs: bool = True,
extension_directory: Path | None = None,
extensions: list[str] | None = None,
install_spatial: bool = True,
use_hive_partitioning: bool = False,
install_extensions: bool = True,
custom_extension_repository: str | None = None,
extension_directory: str | None = None,
) -> None:
"""Creates a new duckdb client.

Args:
use_s3_credential_chain: If true, configures DuckDB to correctly
handle s3:// urls.
use_azure_credential_chain: If true, configures DuckDB to correctly
handle azure urls.
use_httpfs: If true, configures DuckDB to correctly handle https
urls.
use_hive_partitioning: If true, enables queries on hive partitioned
geoparquet files.
install_extensions: If true, installs extensions before loading them.
custom_extension_repository: A custom extension repository to use.
extension_directory: A non-standard extension directory to use.
extensions: A list of extensions to LOAD on client initialization.
install_spatial: Whether to install the spatial extension on client initialization.
use_hive_partitioning: Whether to use hive partitioning for geoparquet queries.
"""

def execute(self, sql: str, params: list[str] | None = None) -> int:
"""Execute an SQL command.

This can be useful for configuring AWS credentials, for example.

Args:
sql: The SQL to execute
params: The parameters to pass in to the execution

Returns:
The integer count produced by the underlying DuckDB execute call
(presumably the number of rows changed; confirm against DuckDB's docs).
"""

def search(
Expand Down
1 change: 0 additions & 1 deletion scripts/test
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@

set -e

uv run maturin dev --uv -E arrow
uv run pytest "$@"
uv run rustac translate spec-examples/v1.1.0/simple-item.json /dev/null
60 changes: 37 additions & 23 deletions src/duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,59 @@ use crate::{
search::{PySortby, StringOrDict, StringOrList},
Result,
};
use duckdb::Connection;
use pyo3::{
exceptions::PyException,
prelude::*,
types::{PyDict, PyList},
IntoPyObjectExt,
};
use pyo3_arrow::PyTable;
use stac_duckdb::{Client, Config};
use std::sync::Mutex;
use stac_duckdb::Client;
use std::{path::PathBuf, sync::Mutex};

#[pyclass(frozen)]
pub struct DuckdbClient(Mutex<Client>);

#[pymethods]
impl DuckdbClient {
// Creates a new client backed by an in-memory DuckDB connection.
//
// NOTE(review): this span appears to come from a rendered diff in which removed and
// added lines are interleaved (two `signature` attributes, `extension_directory`
// declared twice, a `Config`-based body followed by a `Connection`-based one).
// Confirm against the real file before relying on the comments below.
#[new]
#[pyo3(signature = (*, use_s3_credential_chain=false, use_azure_credential_chain=false, use_httpfs=false, use_hive_partitioning=false, install_extensions=true, custom_extension_repository=None, extension_directory=None))]
#[pyo3(signature = (*, extension_directory=None, extensions=Vec::new(), install_spatial=true, use_hive_partitioning=false))]
fn new(
use_s3_credential_chain: bool,
use_azure_credential_chain: bool,
use_httpfs: bool,
extension_directory: Option<PathBuf>,
extensions: Vec<String>,
install_spatial: bool,
use_hive_partitioning: bool,
install_extensions: bool,
custom_extension_repository: Option<String>,
extension_directory: Option<String>,
) -> Result<DuckdbClient> {
// Presumably the pre-change construction path (via `Config`); verify.
let config = Config {
use_s3_credential_chain,
use_azure_credential_chain,
use_httpfs,
use_hive_partitioning,
install_extensions,
custom_extension_repository,
extension_directory,
convert_wkb: true,
};
let client = Client::with_config(config)?;
// Presumably the post-change construction path: configure a raw connection, then wrap it.
let connection = Connection::open_in_memory()?;
// Point DuckDB at a custom extension directory, passed as a bound parameter.
if let Some(extension_directory) = extension_directory {
connection.execute(
"SET extension_directory = ?",
[extension_directory.to_string_lossy()],
)?;
}
// Installing the spatial extension is optional; loading it (below) is not.
if install_spatial {
connection.execute("INSTALL spatial", [])?;
}
// NOTE(review): `extension` is interpolated into the SQL text without escaping, so a
// value containing a single quote breaks the statement or changes its meaning.
// Consider doubling embedded quotes before formatting.
for extension in extensions {
connection.execute(&format!("LOAD '{}'", extension), [])?;
}
// `LOAD spatial` runs unconditionally, regardless of `install_spatial` or `extensions`.
connection.execute("LOAD spatial", [])?;
let mut client = Client::from(connection);
client.use_hive_partitioning = use_hive_partitioning;
Ok(DuckdbClient(Mutex::new(client)))
}

/// Executes a single SQL statement on the client's DuckDB connection.
///
/// `params` are bound to the placeholders in `sql`. Omitting `params` or passing
/// `None` binds no parameters, which matches the `params: list[str] | None = None`
/// declaration in the Python type stub.
///
/// Returns the count reported by the client's `execute` call. Fails if the client
/// mutex is poisoned or if the statement itself fails.
#[pyo3(signature = (sql, params=None))]
fn execute(&self, sql: String, params: Option<Vec<String>>) -> Result<usize> {
    let client = self
        .0
        .lock()
        .map_err(|err| PyException::new_err(err.to_string()))?;
    // An absent parameter list is treated exactly like an empty one.
    let bound = duckdb::params_from_iter(params.unwrap_or_default());
    let count = client.execute(&sql, bound)?;
    Ok(count)
}

#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, **kwargs))]
fn search<'py>(
&self,
Expand Down Expand Up @@ -123,10 +136,11 @@ impl DuckdbClient {
.0
.lock()
.map_err(|err| PyException::new_err(err.to_string()))?;
let convert_wkb = client.config.convert_wkb;
client.config.convert_wkb = false;
// FIXME this is awkward
let convert_wkb = client.convert_wkb;
client.convert_wkb = false;
let result = client.search_to_arrow(&href, search);
client.config.convert_wkb = convert_wkb;
client.convert_wkb = convert_wkb;
result?
};
if record_batches.is_empty() {
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ create_exception!(rustac, RustacError, PyException);

#[derive(Debug, Error)]
pub enum Error {
#[error(transparent)]
Duckdb(#[from] duckdb::Error),

#[error(transparent)]
Geojson(#[from] geojson::Error),

Expand Down
50 changes: 44 additions & 6 deletions tests/test_duckdb.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from pathlib import Path

import pytest
import rustac
from geopandas import GeoDataFrame
from rustac import DuckdbClient
from rustac import DuckdbClient, RustacError


@pytest.fixture
def client() -> DuckdbClient:
    """Provide a `DuckdbClient` constructed with all-default arguments."""
    default_client = DuckdbClient()
    return default_client


@pytest.fixture
def extension_directory() -> Path:
    """Return the `duckdb-extensions` directory located next to this test module."""
    tests_dir = Path(__file__).parent
    return tests_dir.joinpath("duckdb-extensions")


def test_search(client: DuckdbClient) -> None:
items = client.search("data/extended-item.parquet")
assert len(items) == 1
Expand All @@ -27,11 +34,6 @@ def test_get_collections(client: DuckdbClient) -> None:
assert len(collections) == 1


@pytest.mark.skip("slow")
def test_init_with_config() -> None:
DuckdbClient(use_s3_credential_chain=True, use_hive_partitioning=True)


def test_search_to_arrow(client: DuckdbClient) -> None:
pytest.importorskip("arro3.core")
table = client.search_to_arrow("data/100-sentinel-2-items.parquet")
Expand All @@ -40,3 +42,39 @@ def test_search_to_arrow(client: DuckdbClient) -> None:
data_frame_table = data_frame.to_arrow()
item_collection = rustac.from_arrow(data_frame_table)
assert len(item_collection["features"]) == 100


def test_custom_extension_directory(extension_directory: Path) -> None:
    """A client given a custom extension directory can still run a search."""
    custom_client = DuckdbClient(extension_directory=extension_directory)
    # Run a real search so everything the client set up is actually exercised.
    href = "data/100-sentinel-2-items.parquet"
    custom_client.search(href)


def test_no_install(tmp_path: Path) -> None:
    """Construction raises `RustacError` with a temporary extension directory and no install."""
    options = {"extension_directory": tmp_path, "install_spatial": False}
    with pytest.raises(RustacError):
        DuckdbClient(**options)


def test_extensions(extension_directory: Path, tmp_path: Path) -> None:
    """An extension passed by file path is loaded without installing anything.

    A first client is created against ``extension_directory`` so the spatial
    extension is fetched there. A second client then receives that file's path
    via ``extensions``, a different extension directory, and
    ``install_spatial=False``, and must still be able to search.
    """
    # Ensure we've fetched the extension
    DuckdbClient(extension_directory=extension_directory)

    # Use a default so a missing file fails with a readable message rather than
    # a bare StopIteration from `next`.
    extension = next(extension_directory.glob("**/spatial.duckdb_extension"), None)
    assert extension is not None, (
        f"spatial.duckdb_extension not found under {extension_directory}"
    )
    client = DuckdbClient(
        extensions=[str(extension)], extension_directory=tmp_path, install_spatial=False
    )
    client.search("data/100-sentinel-2-items.parquet")


def test_execute(client: DuckdbClient, extension_directory: Path) -> None:
    """Smoke test: `execute` accepts a statement with one bound parameter."""
    statement = "SET extension_directory = ?"
    bound_values = [str(extension_directory)]
    client.execute(statement, bound_values)


def test_load_spatial() -> None:
    """Listing `spatial` explicitly in `extensions` must not break construction."""
    requested = ["spatial"]
    DuckdbClient(extensions=requested)


@pytest.mark.skip("slow")
def test_aws_credential_chain(client: DuckdbClient) -> None:
    """Create an S3 credential-chain secret through `execute` (skipped: slow)."""
    create_secret = "CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)"
    client.execute(create_secret)
Loading