Skip to content

Commit 5bd94cf

Browse files
committed
refactor: simplify extension handling in duckdb
1 parent 7274252 commit 5bd94cf

File tree

8 files changed

+90
-26
lines changed

8 files changed

+90
-26
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,5 @@ cython_debug/
164164
#.idea/
165165
.vscode/
166166
docs/generated
167+
168+
tests/duckdb-extensions

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ duckdb-bundled = ["stac-duckdb/bundled"]
1313

1414
[dependencies]
1515
clap = "4.5.31"
16+
duckdb = { version = "1.2.2", features = ["serde_json"] }
1617
geoarrow-array = { git = "https://github.com/geoarrow/geoarrow-rs/", rev = "17bf33e4cf78b060afa08ca9560dc4efd73c2c76" }
1718
geojson = "0.24.1"
1819
pyo3 = { version = "0.24.1", features = ["extension-module"] }

python/rustac/rustac.pyi

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""The power of Rust for the Python STAC ecosystem."""
22

3+
from pathlib import Path
34
from typing import Any, AsyncIterator, Literal, Optional, Tuple
45

56
import arro3.core
@@ -15,28 +16,28 @@ class DuckdbClient:
1516
def __init__(
1617
self,
1718
*,
18-
use_s3_credential_chain: bool = True,
19-
use_azure_credential_chain: bool = True,
20-
use_httpfs: bool = True,
19+
extension_directory: Path | None = None,
20+
extensions: list[str] | None = None,
21+
install_spatial: bool = True,
2122
use_hive_partitioning: bool = False,
22-
install_extensions: bool = True,
23-
custom_extension_repository: str | None = None,
24-
extension_directory: str | None = None,
2523
) -> None:
2624
"""Creates a new duckdb client.
2725
2826
Args:
29-
use_s3_credential_chain: If true, configures DuckDB to correctly
30-
handle s3:// urls.
31-
use_azure_credential_chain: If true, configures DuckDB to correctly
32-
handle azure urls.
33-
use_httpfs: If true, configures DuckDB to correctly handle https
34-
urls.
35-
use_hive_partitioning: If true, enables queries on hive partitioned
36-
geoparquet files.
37-
install_extensions: If true, installs extensions before loading them.
38-
custom_extension_repository: A custom extension repository to use.
3927
extension_directory: A non-standard extension directory to use.
28+
extensions: A list of extensions to LOAD on client initialization.
29+
install_spatial: Whether to install the spatial extension on client initialization.
30+
use_hive_partitioning: Whether to use hive partitioning for geoparquet queries.
31+
"""
32+
33+
def execute(self, sql: str, params: list[str] | None = None) -> int:
34+
"""Execute an SQL command.
35+
36+
This can be useful for configuring AWS credentials, for example.
37+
38+
Args:
39+
sql: The SQL to execute
40+
params: The parameters to pass in to the execution
4041
"""
4142

4243
def search(

scripts/test

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,5 @@
22

33
set -e
44

5-
uv run maturin dev --uv -E arrow
65
uv run pytest "$@"
76
uv run rustac translate spec-examples/v1.1.0/simple-item.json /dev/null

src/duckdb.rs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::{
22
search::{PySortby, StringOrDict, StringOrList},
33
Result,
44
};
5+
use duckdb::Connection;
56
use pyo3::{
67
exceptions::PyException,
78
prelude::*,
@@ -18,13 +19,40 @@ pub struct DuckdbClient(Mutex<Client>);
1819
#[pymethods]
1920
impl DuckdbClient {
2021
#[new]
21-
#[pyo3(signature = (*, use_hive_partitioning=false, extension_directory=None, install_extensions=true))]
22+
#[pyo3(signature = (*, extension_directory=None, extensions=Vec::new(), install_spatial=true, use_hive_partitioning=false))]
2223
fn new(
23-
use_hive_partitioning: bool,
2424
extension_directory: Option<PathBuf>,
25-
install_extensions: bool,
25+
extensions: Vec<String>,
26+
install_spatial: bool,
27+
use_hive_partitioning: bool,
2628
) -> Result<DuckdbClient> {
27-
todo!()
29+
let connection = Connection::open_in_memory()?;
30+
if let Some(extension_directory) = extension_directory {
31+
connection.execute(
32+
"SET extension_directory = ?",
33+
[extension_directory.to_string_lossy()],
34+
)?;
35+
}
36+
if install_spatial {
37+
connection.execute("INSTALL spatial", [])?;
38+
}
39+
for extension in extensions {
40+
connection.execute(&format!("LOAD '{}'", extension), [])?;
41+
}
42+
connection.execute("LOAD spatial", [])?;
43+
let mut client = Client::from(connection);
44+
client.use_hive_partitioning = use_hive_partitioning;
45+
Ok(DuckdbClient(Mutex::new(client)))
46+
}
47+
48+
#[pyo3(signature = (sql, params = Vec::new()))]
49+
fn execute<'py>(&self, sql: String, params: Vec<String>) -> Result<usize> {
50+
let client = self
51+
.0
52+
.lock()
53+
.map_err(|err| PyException::new_err(err.to_string()))?;
54+
let count = client.execute(&sql, duckdb::params_from_iter(params))?;
55+
Ok(count)
2856
}
2957

3058
#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, **kwargs))]

src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ create_exception!(rustac, RustacError, PyException);
99

1010
#[derive(Debug, Error)]
1111
pub enum Error {
12+
#[error(transparent)]
13+
Duckdb(#[from] duckdb::Error),
14+
1215
#[error(transparent)]
1316
Geojson(#[from] geojson::Error),
1417

tests/test_duckdb.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from pathlib import Path
22

33
import pytest
4-
from geopandas import GeoDataFrame
5-
64
import rustac
5+
from geopandas import GeoDataFrame
76
from rustac import DuckdbClient, RustacError
87

98

@@ -12,6 +11,11 @@ def client() -> DuckdbClient:
1211
return DuckdbClient()
1312

1413

14+
@pytest.fixture
15+
def extension_directory() -> Path:
16+
return Path(__file__).parent / "duckdb-extensions"
17+
18+
1519
def test_search(client: DuckdbClient) -> None:
1620
items = client.search("data/extended-item.parquet")
1721
assert len(items) == 1
@@ -40,13 +44,37 @@ def test_search_to_arrow(client: DuckdbClient) -> None:
4044
assert len(item_collection["features"]) == 100
4145

4246

43-
def test_custom_extension_directory() -> None:
44-
extension_directory = Path(__file__).parent / "duckdb-extensions"
47+
def test_custom_extension_directory(extension_directory: Path) -> None:
4548
client = DuckdbClient(extension_directory=extension_directory)
4649
# Search to ensure we trigger everything
4750
client.search("data/100-sentinel-2-items.parquet")
4851

4952

5053
def test_no_install(tmp_path: Path) -> None:
5154
with pytest.raises(RustacError):
52-
DuckdbClient(extension_directory=tmp_path, install_extensions=False)
55+
DuckdbClient(extension_directory=tmp_path, install_spatial=False)
56+
57+
58+
def test_extensions(extension_directory: Path, tmp_path: Path) -> None:
59+
# Ensure we've fetched the extension
60+
DuckdbClient(extension_directory=extension_directory)
61+
62+
extension = next(extension_directory.glob("**/spatial.duckdb_extension"))
63+
client = DuckdbClient(
64+
extensions=[str(extension)], extension_directory=tmp_path, install_spatial=False
65+
)
66+
client.search("data/100-sentinel-2-items.parquet")
67+
68+
69+
def test_execute(client: DuckdbClient, extension_directory: Path) -> None:
70+
# Just a smoke test
71+
client.execute("SET extension_directory = ?", [str(extension_directory)])
72+
73+
74+
def test_load_spatial() -> None:
75+
DuckdbClient(extensions=["spatial"])
76+
77+
78+
@pytest.mark.skip("slow")
79+
def test_aws_credential_chain(client: DuckdbClient) -> None:
80+
client.execute("CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)")

0 commit comments

Comments
 (0)